# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import math
import pandas as pd
import numpy as np
import wandb
# from edit_distance import SequenceMatcher # Assuming this is available for PER
import editdistance # For PER calculation
import jiwer
import string
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
import os
from pathlib import Path # For constructing paths
from typing import List, Dict, Tuple, Optional 
from transformers.modeling_outputs import BaseModelOutput # Add this import

from misc_architectures import (
        EnhancedDayTransformation, LightFiLMDayTransformation, DayTransformation, HybridDayTransformation, # Day Transforms
        BinnedECoGFeatureExtractor, BinnedAttentionECoGFeatureExtractor, BinnedAttentionECoGFeatureExtractorDownsampling, # Feature Extractors
        ECoGFeatureExtractor, InterpretableECoGExtractor, Unfold # Feature Extractors
    )


# --- Custom Transformer Encoder for efficient intermediate output access ---
class CustomTransformerEncoder(nn.Module):
    """Stack of encoder layers that can also return one intermediate layer's output.

    Mirrors the layer loop of ``nn.TransformerEncoder`` (with no final norm):
    every module in ``layers`` is applied in order. In the same pass, the output
    of the layer at ``intermediate_layer_idx`` is captured so auxiliary heads can
    tap it without running the stack twice.

    Args:
        layers: Encoder layers, each called as
            ``layer(x, src_mask=..., src_key_padding_mask=...)``.
        intermediate_layer_idx: Index of the layer whose output is returned as the
            intermediate result. Negative values count from the end like Python
            indexing (``-1`` is the last layer, i.e. the final output). Values that
            are still out of range fall back to the final output.
    """

    def __init__(self, layers: nn.ModuleList, intermediate_layer_idx: int = -1):
        super().__init__()
        self.layers = layers
        self.intermediate_layer_idx = intermediate_layer_idx

    def forward(self, src, src_key_padding_mask=None, mask=None):
        """Run every layer and return ``(final_output, intermediate_output)``.

        Args:
            src: Input to the first layer.
            src_key_padding_mask: Forwarded unchanged to every layer.
            mask: Forwarded to every layer as ``src_mask``.

        Returns:
            The last layer's output and the selected intermediate layer's output.
            The intermediate entry is the final output when the index does not
            select any layer. With an empty ``layers``, both entries are ``src``.
        """
        tap_idx = self.intermediate_layer_idx
        # Resolve negative indices Python-style. Anything still out of range
        # never matches below and therefore falls back to the final output.
        if tap_idx < 0:
            tap_idx += len(self.layers)

        output = src
        intermediate_output = None
        for i, layer in enumerate(self.layers):
            output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
            if i == tap_idx:
                intermediate_output = output

        # No layer was selected: the intermediate result is the final output.
        if intermediate_output is None:
            intermediate_output = output

        return output, intermediate_output


# --- Constants ---
# Special token ids of the phoneme vocabulary. PAD_IDX is passed as ``padding_idx``
# to the phoneme embedding in NeuralToPhonemeTransformer. The other names suggest
# start-of-sequence, end-of-sequence and word-separator tokens.
# NOTE(review): these values must agree with the vocabulary used to build the
# training targets, which is not visible here — confirm.
SOS_IDX = 41
EOS_IDX = 42
PAD_IDX = 0
SPACE_TOKEN_ID = 40 

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence, then apply dropout.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns the matching
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / d_model)``. The table is
    precomputed for ``max_len`` positions and registered as the buffer ``pe``
    (shape ``[1, max_len, d_model]``), so it follows the module across devices
    and is stored in the state dict.

    Args:
        d_model: Feature dimension of the inputs. Odd values are supported; the
            last column is then a sine column without a cosine partner.
        dropout: Dropout probability applied after adding the encodings.
        max_len: Longest sequence length covered by the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than there are frequencies,
        # so trim the frequencies to match (a no-op when d_model is even).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])``.

        ``x`` is indexed as ``[batch, seq_len, d_model]`` to line up with the
        ``[1, max_len, d_model]`` table; ``seq_len`` must not exceed ``max_len``.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
    
class Permute(nn.Module):
    """Module wrapper around ``Tensor.permute`` for use inside ``nn.Sequential``.

    Args:
        dims: Target ordering of the input's dimensions, e.g. ``[0, 2, 1]`` swaps
            the last two axes of a 3-D tensor.
    """

    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        """Return a view of ``x`` with its dimensions reordered to ``self.dims``."""
        new_order = tuple(self.dims)
        return torch.permute(x, new_order)

class NeuralToPhonemeTransformer(pl.LightningModule):
    def __init__(self, hparams):
        """Build the ECoG encoder, phoneme decoder and all optional heads from ``hparams``.

        Args:
            hparams: Hyperparameter container passed to ``save_hyperparameters``.
                Settings are then read back through ``self.hparams`` (attribute
                access and ``.get``). The following keys are read without a default:
                ``ecog_n_channels``, ``d_model``, ``dropout``, ``n_head``,
                ``dim_feedforward``, ``num_encoder_layers``, ``num_decoder_layers``,
                ``vocab_size``, and ``num_mfcc_features`` when ``train_mfcc_aux``
                is set. All other keys fall back to a default.

        Raises:
            ValueError: If ``feature_extractor_type`` or ``mfcc_aux_head_type`` is unknown.

        NOTE(review): submodules are created in a fixed order. PyTorch layers draw
        their initial weights from the global RNG when constructed, so reordering
        statements here changes the seeded initialization of every layer built
        afterwards.
        """
        super().__init__()
        self.save_hyperparameters(hparams) # hparams is expected to be a Namespace or dict

        # --- Noise Augmentation Params ---
        self.white_noise_sd = self.hparams.get('white_noise_sd', 0.0)
        self.constant_offset_sd = self.hparams.get('constant_offset_sd', 0.0)

        # --- SpecAugment Params ---
        self.time_masking_prob = self.hparams.get('time_masking_prob', 0.0)
        self.time_mask_max_len = self.hparams.get('time_mask_max_len', 0)
        self.time_mask_max_proportion = self.hparams.get('time_mask_max_proportion', 0.0)
        self.channel_masking_prob = self.hparams.get('channel_masking_prob', 0.0)
        self.channel_mask_max_electrodes = self.hparams.get('channel_mask_max_electrodes', 0)

        # --- Downsampling Configuration ---
        # 'conv' and 'unfold' both read downsampling_factor (default 1); every other
        # strategy forces a factor of 1.
        self.downsampling_strategy = self.hparams.get('downsampling_strategy', 'none')
        if self.downsampling_strategy == 'conv':
            self.downsample_factor = self.hparams.get('downsampling_factor', 1)
        elif self.downsampling_strategy == 'unfold':
            self.downsample_factor = self.hparams.get('downsampling_factor', 1)
        else:
            self.downsample_factor = 1
        rank_zero_info(f"Downsampling strategy: {self.downsampling_strategy}, Factor: {self.downsample_factor}")


        # --- ECoG Feature Extractor ---
        # Builds self.input_projection from (ecog_n_channels, d_model). The
        # feature_extractor_type hparam selects which extractor is used.
        ecog_n_channels = self.hparams.ecog_n_channels
        d_model = self.hparams.d_model
        feature_extraction_type = self.hparams.get('feature_extractor_type', 'FC') # Default to linear

        if feature_extraction_type == "unfold":
            # Unfold presumably stacks downsample_factor frames into one feature
            # vector, hence the multiplied width — confirm in misc_architectures.
            # downsample_factor is 1 unless downsampling_strategy is 'conv' or 'unfold'.
            unfolded_dim = ecog_n_channels * self.downsample_factor
            projection_layers = [Unfold(self.downsample_factor)]
            if unfolded_dim != d_model:
                projection_layers.append(nn.Linear(unfolded_dim, d_model))
            self.input_projection = nn.Sequential(*projection_layers)
        elif feature_extraction_type == "FC":
            self.input_projection = nn.Linear(ecog_n_channels, d_model)
        elif feature_extraction_type == "simple_conv": # Matches 'Simple_conv'
            # NOTE(review): Conv1d consumes channels-first input ([B, C, L]); confirm
            # the forward pass feeds this branch in that layout.
            self.input_projection = nn.Sequential(
                nn.Conv1d(ecog_n_channels, d_model, kernel_size=3, padding=1),
                Permute([0, 2, 1]) # To make it [B, L, d_model]
            )
        elif feature_extraction_type == "binned_conv": # Matches 'Binned_conv'
            self.input_projection = BinnedECoGFeatureExtractor(ecog_n_channels, d_model)
        elif feature_extraction_type == "binned_attention_conv":
            self.input_projection = BinnedAttentionECoGFeatureExtractor(ecog_n_channels, d_model)
        elif feature_extraction_type == "binned_attention_conv_downsample":
            # Note: the factor used here comes from downsampling_strategy ('conv' or
            # 'unfold'); with any other strategy it is 1.
            kernel_size = self.hparams.get('feature_extractor_kernel_size', 3)
            self.input_projection = BinnedAttentionECoGFeatureExtractorDownsampling(
                n_channels=ecog_n_channels, 
                feature_dim=d_model, 
                downsample_factor=self.downsample_factor,
                kernel_size=kernel_size
            )
        elif feature_extraction_type == "deep_conv": # Matches 'Deep_conv'
            self.input_projection = ECoGFeatureExtractor(ecog_n_channels, d_model)
        elif feature_extraction_type == "interpretable_conv":
            self.input_projection = InterpretableECoGExtractor(ecog_n_channels, d_model)
        else:
            raise ValueError(f"Invalid feature_extraction_type: {feature_extraction_type}")
        
        # --- Optional Day Transformation ---
        # An unknown day_transform_type only prints a warning; day_transform then
        # stays None.
        use_day_transform_hparam = self.hparams.get('use_day_transform', False)
        day_transform_type = self.hparams.get('day_transform_type', 'FC') # Default to simple FC
        num_days = self.hparams.get('num_days', 24) # Default to 24 if not specified

        self.day_transform = None
        if use_day_transform_hparam:
            if day_transform_type == "FiLM":
                self.day_transform = EnhancedDayTransformation(nDays=num_days, neural_dim=d_model)
            elif day_transform_type == "LightFiLM":
                self.day_transform = LightFiLMDayTransformation(nDays=num_days, neural_dim=d_model)
            elif day_transform_type == "FC": # Fully Connected DayTransformation
                self.day_transform = DayTransformation(nDays=num_days, neural_dim=d_model)
            elif day_transform_type == "Hybrid":
                self.day_transform = HybridDayTransformation(nDays=num_days, neural_dim=d_model)
            else:
                print(f"Warning: Unknown day_transform_type '{day_transform_type}'. No day transformation will be applied.")
                self.day_transform = None
        
        # Positional Encodings
        self.pos_encoder = PositionalEncoding(d_model, self.hparams.dropout)
        self.pos_decoder = PositionalEncoding(d_model, self.hparams.dropout) # Separate instance for decoder

        # --- Transformer Encoder ---
        self.encoder_implementation = self.hparams.get('encoder_implementation', 'custom_class')
        aux_head_input_layer_idx = self.hparams.get('aux_head_input_layer_idx', -1)
        
        # CustomTransformerEncoder, which returns (final, intermediate), is used only
        # when encoder_implementation == 'custom_class' AND aux_head_input_layer_idx
        # is a valid non-negative layer index. Otherwise nn.TransformerEncoder is used;
        # it returns a single tensor.
        # NOTE(review): code calling self.encoder must handle both return types.
        if self.encoder_implementation == 'custom_class' and aux_head_input_layer_idx != -1 and \
           (0 <= aux_head_input_layer_idx < self.hparams.num_encoder_layers):
            # Each layer is constructed (and initialized) independently.
            full_encoder_layer_stack = nn.ModuleList([
                nn.TransformerEncoderLayer(
                    d_model=self.hparams.d_model, nhead=self.hparams.n_head,
                    dim_feedforward=self.hparams.dim_feedforward, dropout=self.hparams.dropout,
                    activation=self.hparams.get('transformer_activation', 'relu'), batch_first=True
                ) for _ in range(self.hparams.num_encoder_layers)
            ])
            self.encoder = CustomTransformerEncoder(
                layers=full_encoder_layer_stack,
                intermediate_layer_idx=aux_head_input_layer_idx
            )
            print(f"Using CustomTransformerEncoder. Aux heads will tap layer {aux_head_input_layer_idx}.")
        else:
            # nn.TransformerEncoder deep-copies encoder_layer num_layers times, so all
            # layers on this path start from identical weights, unlike the custom path.
            # 'loop' builds the same module as the default; here it only changes the log
            # message (any difference would live in the forward pass, not visible here).
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=self.hparams.d_model, nhead=self.hparams.n_head,
                dim_feedforward=self.hparams.dim_feedforward, dropout=self.hparams.dropout,
                activation=self.hparams.get('transformer_activation', 'relu'), batch_first=True
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=self.hparams.num_encoder_layers)
            if self.encoder_implementation == 'loop':
                print("Using standard nn.TransformerEncoder with looping forward pass.")
            else:
                print("Using standard nn.TransformerEncoder. Aux heads will receive final encoder output.")


        # --- Phoneme Embedding ---
        # PAD_IDX rows get no gradient (nn.Embedding padding_idx semantics).
        self.phoneme_embedding = nn.Embedding(self.hparams.vocab_size, self.hparams.d_model, padding_idx=PAD_IDX)

        # --- Transformer Decoder ---
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.hparams.d_model,
            nhead=self.hparams.n_head,
            dim_feedforward=self.hparams.dim_feedforward,
            dropout=self.hparams.dropout,
            activation=self.hparams.get('transformer_activation', 'relu'),
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=self.hparams.num_decoder_layers)

        # --- Main Output Projection ---
        self.output_projection = nn.Linear(self.hparams.d_model, self.hparams.vocab_size)

        # --- Auxiliary Prediction Heads (conditionally initialized and downsample-aware) ---
        self.aux_heads = nn.ModuleDict()
        
        if self.hparams.get('train_mfcc_aux', False):
            # The output dim of the aux head depends on the downsampling strategy.
            # NOTE(review): num_mfcc_features is read here without a default. The
            # fallback to 14 further down in this method runs too late, so enabling
            # train_mfcc_aux without num_mfcc_features fails on this line.
            final_mfcc_dim = self.hparams.num_mfcc_features
            if self.downsample_factor > 1:
                # When downsampling, we unfold the target, increasing its feature dim
                final_mfcc_dim *= self.downsample_factor

            mfcc_head_type = self.hparams.get('mfcc_aux_head_type', 'linear')
            rank_zero_info(f"Using '{mfcc_head_type}' for MFCC auxiliary head.")

            if mfcc_head_type == 'mlp':
                self.aux_heads['mfcc'] = nn.Sequential(
                    nn.Linear(self.hparams.d_model, self.hparams.d_model // 2),
                    nn.ReLU(),
                    nn.Dropout(self.hparams.dropout),
                    nn.Linear(self.hparams.d_model // 2, final_mfcc_dim)
                )
            elif mfcc_head_type == 'linear':
                self.aux_heads['mfcc'] = nn.Linear(self.hparams.d_model, final_mfcc_dim)
            else:
                raise ValueError(f"Unknown mfcc_aux_head_type: {mfcc_head_type}")

        if self.hparams.get('train_envelope_aux', False):
            # The output dim for envelope also scales with the downsample factor
            final_envelope_dim = 1
            if self.downsample_factor > 1:
                final_envelope_dim *= self.downsample_factor

            self.aux_heads['envelope'] = nn.Sequential(
                nn.Linear(self.hparams.d_model, self.hparams.d_model // 2),
                nn.ReLU(),
                nn.Dropout(self.hparams.dropout),
                nn.Linear(self.hparams.d_model // 2, final_envelope_dim)
            )
        
        # Sequence-level heads: each maps a d_model feature vector to one scalar.
        # Which pooled feature they receive is not chosen here.
        if self.hparams.get('train_phoneme_len_aux', False):
            self.aux_heads['phoneme_len'] = nn.Sequential(
                nn.Linear(self.hparams.d_model, self.hparams.d_model // 2),
                nn.ReLU(),
                nn.Dropout(self.hparams.dropout),
                nn.Linear(self.hparams.d_model // 2, 1)
            )
        if self.hparams.get('train_word_count_aux', False):
            self.aux_heads['word_count'] = nn.Sequential(
                nn.Linear(self.hparams.d_model, self.hparams.d_model // 2),
                nn.ReLU(),
                nn.Dropout(self.hparams.dropout),
                nn.Linear(self.hparams.d_model // 2, 1)
            )
            
        # To store validation outputs for detailed logging
        self.validation_step_outputs = []

        # Token lookup taken from hparams.PHONEME_MAP. An empty dict means phoneme
        # strings will fall back to raw ids, as the warning below says.
        self.idx_to_token_map = getattr(self.hparams, 'PHONEME_MAP', {})
        if not self.idx_to_token_map:
            print("Warning: self.hparams.PHONEME_MAP is not set. Phoneme string conversion will use raw IDs.")


        # --- Layers for Augmenting Encoder/Decoder Inputs (Initialized Unconditionally) ---
        # These layers are always built, presumably so the parameter set does not depend
        # on the augment_* flags; the flags below only log which augmentations are active.
        
        # For Encoder Output Augmentation (MFCCs)
        if not hasattr(self.hparams, 'num_mfcc_features'):
            # The encoder fusion layer's width needs a value even when MFCC augmentation
            # is off, so fall back to 14 and record it in hparams.
            print("Warning: hparams.num_mfcc_features not found. Defaulting to 14 for layer definition.")
            self.hparams.num_mfcc_features = self.hparams.get('num_mfcc_features', 14) # Ensure it exists
            
        self.encoder_feature_fusion_projection = nn.Linear(
            self.hparams.d_model + self.hparams.num_mfcc_features,
            self.hparams.d_model
        )

        # For Decoder Input Augmentation (Global Features)
        if not hasattr(self.hparams, 'global_feature_emb_dim'):
            print("Warning: hparams.global_feature_emb_dim not found. Defaulting to 32 for layer definition.")
            self.hparams.global_feature_emb_dim = self.hparams.get('global_feature_emb_dim', 32)

        self.len_embedder = nn.Linear(1, self.hparams.global_feature_emb_dim)
        self.word_count_embedder = nn.Linear(1, self.hparams.global_feature_emb_dim)
        
        # Calculate the maximum possible number of global features that could be active
        # This is for defining the decoder_input_fusion_projection layer's input size.
        # The actual number of concatenated features will be dynamic in the forward pass.
        max_possible_global_features = 2 # (phoneme_len, word_count)
        
        self.decoder_input_fusion_projection = nn.Linear(
            self.hparams.d_model + (max_possible_global_features * self.hparams.global_feature_emb_dim),
            self.hparams.d_model # Project back to d_model for decoder layers
        )
        
        if self.hparams.get('augment_encoder_with_mfcc', False):
            print("Encoder augmentation with MFCC is configured to be active.")
        if self.hparams.get('augment_decoder_with_phoneme_len', False):
            print("Decoder augmentation with phoneme length is configured to be active.")
        if self.hparams.get('augment_decoder_with_word_count', False):
            print("Decoder augmentation with word count is configured to be active.")

        # --- BART Text Decoder Head (New) ---
        self.bart_text_decoder_active = self.hparams.get('train_bart_text_decoder', False)
        self.bart_model = None
        self.bart_tokenizer = None

        # --- Whisper Text Decoder Head (New) ---
        self.whisper_text_decoder_active = self.hparams.get('train_whisper_text_decoder', False)
        self.whisper_model = None
        self.whisper_tokenizer = None
        self.whisper_processor = None # Whisper uses a processor

        # BART and Whisper heads may both be enabled; no mutual-exclusion check is enforced.

        # For both pretrained heads, any loading error is only printed and the head's
        # *_active flag is cleared.
        # NOTE(review): if the model loads but the tokenizer/processor does not, the
        # partially loaded model stays attached. ecog_to_*_hidden_projection is only
        # created once both have loaded.
        if self.bart_text_decoder_active:
            try:
                from transformers import BartForConditionalGeneration, BartTokenizer
                bart_type = self.hparams.get('bart_type', 'bart-base')
                print(f"Initializing BART model and tokenizer (facebook/{bart_type})...")
                self.bart_model = BartForConditionalGeneration.from_pretrained(f"facebook/{bart_type}")
                self.bart_tokenizer = BartTokenizer.from_pretrained(f"facebook/{bart_type}")
                
                # Map d_model encoder features to BART's hidden size when they differ.
                if self.hparams.d_model != self.bart_model.config.hidden_size:
                    print(f"Warning: d_model ({self.hparams.d_model}) != BART hidden_size ({self.bart_model.config.hidden_size}). Adding projection layer.")
                    self.ecog_to_bart_hidden_projection = nn.Sequential(
                        nn.LayerNorm(self.hparams.d_model),
                        nn.Linear(self.hparams.d_model, self.bart_model.config.hidden_size)
                    )
                else:
                    self.ecog_to_bart_hidden_projection = nn.Identity()
                print("BART model and tokenizer initialized.")
            except ImportError:
                print("ERROR: `transformers` library not found. BART decoder head cannot be initialized.")
                self.bart_text_decoder_active = False
            except Exception as e:
                print(f"Error initializing BART model/tokenizer: {e}")
                self.bart_text_decoder_active = False
        
        if self.whisper_text_decoder_active:
            try:
                from transformers import WhisperForConditionalGeneration, WhisperTokenizer, WhisperProcessor
                whisper_type = self.hparams.get('whisper_type', 'whisper-base')
                print(f"Initializing Whisper model and processor (openai/{whisper_type})...")
                self.whisper_model = WhisperForConditionalGeneration.from_pretrained(f"openai/{whisper_type}")
                self.whisper_processor = WhisperProcessor.from_pretrained(f"openai/{whisper_type}")
                self.whisper_tokenizer = self.whisper_processor.tokenizer # For easy access

                # Map d_model encoder features to Whisper's hidden size when they differ.
                if self.hparams.d_model != self.whisper_model.config.hidden_size:
                    print(f"Warning: d_model ({self.hparams.d_model}) != Whisper hidden_size ({self.whisper_model.config.hidden_size}). Adding projection layer.")
                    self.ecog_to_whisper_hidden_projection = nn.Sequential(
                        nn.LayerNorm(self.hparams.d_model),
                        nn.Linear(self.hparams.d_model, self.whisper_model.config.hidden_size)
                    )
                else:
                    self.ecog_to_whisper_hidden_projection = nn.Identity()
                print("Whisper model and processor initialized.")
            except ImportError:
                print("ERROR: `transformers` library not found. Whisper decoder head cannot be initialized.")
                self.whisper_text_decoder_active = False
            except Exception as e:
                print(f"Error initializing Whisper model/processor: {e}")
                self.whisper_text_decoder_active = False

        # --- Freezing logic based on config ---
        # Runs last so it sees every submodule created above.
        self._apply_freezing_from_hparams()

    def _apply_freezing_from_hparams(self):
        # Freeze encoder if specified
        if getattr(self.hparams, "freeze_encoder_in_joint", False):
            print("Encoder parts are FROZEN for joint training.")
            for param in self.encoder.parameters():
                param.requires_grad = False

        # if getattr(self.hparams, "freeze_encoder_in_joint", False) and \
        #    self.hparams.get('training_stage', '') != 'encoder_only': # only freeze if not in encoder_only stage
        #     print("Encoder parts are FROZEN for joint training.")
        #     if hasattr(self, 'encoder'):
        #         for param in self.encoder.parameters():
        #             param.requires_grad = False
        #     if hasattr(self, 'input_projection'):
        #         print(f"Freezing parameters of: input_projection")
        #         for param in self.input_projection.parameters():
        #             param.requires_grad = False
        #     if hasattr(self, 'day_transform') and self.day_transform is not None:
        #          print(f"Freezing parameters of: day_transform")
        #          for param in self.day_transform.parameters():
        #                 param.requires_grad = False

        # Freeze decoder if specified
        if getattr(self.hparams, "freeze_decoder_in_joint", False) and \
            self.hparams.get('training_stage', '') != 'decoder_only': # Example condition
            print("Decoder parts are FROZEN for joint training.")
            if hasattr(self, 'decoder'):
                for param in self.decoder.parameters():
                    param.requires_grad = False
            if hasattr(self, 'phoneme_embedding'):
                 print(f"Freezing parameters of: phoneme_embedding")
                 for param in self.phoneme_embedding.parameters():
                        param.requires_grad = False
            if hasattr(self, 'output_projection'):
                print(f"Freezing parameters of: output_projection")
                for param in self.output_projection.parameters():
                    param.requires_grad = False
            # Consider freezing pos_decoder as well if it's part of the decoder path
            if hasattr(self, 'pos_decoder'):
                print(f"Freezing parameters of: pos_decoder")
                for param in self.pos_decoder.parameters():
                    param.requires_grad = False

        # Freeze input_projection if specified
        if getattr(self.hparams, "freeze_input_projection_in_joint", False):
            print("Input projection is FROZEN for joint training.")
            for param in self.input_projection.parameters():
                param.requires_grad = False

        # Freeze day_transform if specified
        if getattr(self.hparams, "freeze_day_transform_in_joint", False):
            print("Day transform is FROZEN for joint training.")
            for param in self.day_transform.parameters():
                param.requires_grad = False

        # Freeze augmentation-related projection layers if specified
        if getattr(self.hparams, "freeze_augmentation_projections_in_joint", False):
            print("Augmentation projection layers are FROZEN for joint training.")
            if hasattr(self, 'encoder_feature_fusion_projection'):
                for param in self.encoder_feature_fusion_projection.parameters():
                    param.requires_grad = False
            if hasattr(self, 'len_embedder'):
                for param in self.len_embedder.parameters():
                    param.requires_grad = False
            if hasattr(self, 'word_count_embedder'):
                for param in self.word_count_embedder.parameters():
                    param.requires_grad = False
            if hasattr(self, 'decoder_input_fusion_projection'):
                for param in self.decoder_input_fusion_projection.parameters():
                    param.requires_grad = False
            if hasattr(self, 'ecog_to_bart_hidden_projection') and isinstance(self.ecog_to_bart_hidden_projection, nn.Linear):
                 for param in self.ecog_to_bart_hidden_projection.parameters():
                    param.requires_grad = False

        # --- BART Freezing Logic (New) ---
        if self.bart_text_decoder_active and self.bart_model is not None:
            bart_freezing_strategy = self.hparams.get('bart_freezing_strategy', 'none') # 'none', 'decoder_only', 'cross_attn_only', 'freeze_all_bart', 'freeze_early_layers
            print(f"Applying BART freezing strategy: {bart_freezing_strategy}")

            if bart_freezing_strategy == 'freeze_all_bart':
                print("BART model: ALL parameters FROZEN.")
                for param in self.bart_model.parameters():
                    param.requires_grad = False
            
            elif bart_freezing_strategy == 'decoder_only':
                print("BART model: Training DECODER ONLY (encoder and embeddings frozen).")
                # Freeze BART's encoder
                if hasattr(self.bart_model, 'model') and hasattr(self.bart_model.model, 'encoder'):
                    for param in self.bart_model.model.encoder.parameters():
                        param.requires_grad = False
                # Freeze BART's embeddings (shared for encoder/decoder)
                # if hasattr(self.bart_model, 'model') and hasattr(self.bart_model.model, 'shared'):
                #      for param in self.bart_model.model.shared.parameters():
                #         param.requires_grad = False
                # Ensure decoder is trainable (it should be by default unless 'freeze_all_bart')
                if hasattr(self.bart_model, 'model') and hasattr(self.bart_model.model, 'decoder'):
                    for param in self.bart_model.model.decoder.parameters():
                        param.requires_grad = True
                if hasattr(self.bart_model, 'lm_head'): # BART's output projection
                    for param in self.bart_model.lm_head.parameters():
                        param.requires_grad = True

            elif bart_freezing_strategy == 'freeze_first_3_layers':
                print("BART model: Freezing early decoder layers (0–2), training top decoder layers, encoder, and lm_head.")
                # Freeze encoder
                if hasattr(self.bart_model, 'model') and hasattr(self.bart_model.model, 'encoder'):
                    for param in self.bart_model.model.encoder.parameters():
                        param.requires_grad = False
                # Freeze early decoder layers (e.g., 0–2)
                if hasattr(self.bart_model, 'model') and hasattr(self.bart_model.model, 'decoder'):
                    decoder_layers = self.bart_model.model.decoder.layers
                    for i in range(len(decoder_layers)):
                        if i <= 2:
                            for param in decoder_layers[i].parameters():
                                param.requires_grad = False
                        else:
                            for param in decoder_layers[i].parameters():
                                param.requires_grad = True
                # Make sure lm_head is trainable
                if hasattr(self.bart_model, 'lm_head'):
                    for param in self.bart_model.lm_head.parameters():
                        param.requires_grad = True

            elif bart_freezing_strategy == 'freeze_first_8_layers':
                print("BART model: Freezing early decoder layers (0–7), training top decoder layers, encoder, and lm_head.")
                # Freeze encoder
                if hasattr(self.bart_model, 'model') and hasattr(self.bart_model.model, 'encoder'):
                    for param in self.bart_model.model.encoder.parameters():
                        param.requires_grad = False

                if hasattr(self.bart_model, 'model') and hasattr(self.bart_model.model, 'decoder'):
                    decoder_layers = self.bart_model.model.decoder.layers
                    for i in range(len(decoder_layers)):
                        if i <= 7:
                            for param in decoder_layers[i].parameters():
                                param.requires_grad = False
                        else:
                            for param in decoder_layers[i].parameters():
                                param.requires_grad = True
                # Make sure lm_head is trainable
                if hasattr(self.bart_model, 'lm_head'):
                    for param in self.bart_model.lm_head.parameters():
                        param.requires_grad = True


            elif bart_freezing_strategy == 'cross_attn_only':
                print("BART model: Training CROSS-ATTENTION layers in decoder ONLY.")
                for name, param in self.bart_model.named_parameters():
                    param.requires_grad = False # Freeze all by default
                
                if hasattr(self.bart_model, 'model') and hasattr(self.bart_model.model, 'decoder'):
                    for layer in self.bart_model.model.decoder.layers:
                        if hasattr(layer, 'encoder_attn'):
                            print(f"  Unfreezing cross-attention: {layer}.encoder_attn")
                            for param in layer.encoder_attn.parameters():
                                param.requires_grad = True
                        # Also typically unfreeze the layer norm following cross-attention
                        if hasattr(layer, 'encoder_attn_layer_norm'):
                            print(f"  Unfreezing layer_norm for cross-attention: {layer}.encoder_attn_layer_norm")
                            for param in layer.encoder_attn_layer_norm.parameters():
                                param.requires_grad = True
                # The lm_head should also be trainable if we expect the model to learn predictions
                if hasattr(self.bart_model, 'lm_head'):
                    print("  Unfreezing lm_head for BART.")
                    for param in self.bart_model.lm_head.parameters():
                        param.requires_grad = True
            
            elif bart_freezing_strategy == 'none':
                print("BART model: ALL parameters TRAINABLE.")
                for param in self.bart_model.parameters():
                    param.requires_grad = True
            else:
                print(f"Warning: Unknown bart_freezing_strategy '{bart_freezing_strategy}'. BART parameters will not be specifically frozen by this strategy.")
        
        # --- Whisper Freezing Logic (New) ---
        if self.whisper_text_decoder_active and self.whisper_model is not None:
            whisper_freezing_strategy = self.hparams.get('whisper_freezing_strategy', 'none')
            print(f"Applying Whisper freezing strategy: {whisper_freezing_strategy}")

            if whisper_freezing_strategy == 'freeze_all_whisper':
                print("Whisper model: ALL parameters FROZEN.")
                for param in self.whisper_model.parameters():
                    param.requires_grad = False
            
            elif whisper_freezing_strategy == 'decoder_only':
                print("Whisper model: Training DECODER ONLY (encoder frozen).")
                if hasattr(self.whisper_model, 'model') and hasattr(self.whisper_model.model, 'encoder'):
                    for param in self.whisper_model.model.encoder.parameters():
                        param.requires_grad = False
                # Ensure decoder is trainable
                if hasattr(self.whisper_model, 'model') and hasattr(self.whisper_model.model, 'decoder'):
                    for param in self.whisper_model.model.decoder.parameters():
                        param.requires_grad = True
                if hasattr(self.whisper_model, 'proj_out'): # Whisper's output projection
                    for param in self.whisper_model.proj_out.parameters():
                        param.requires_grad = True

            elif whisper_freezing_strategy == 'cross_attn_only':
                print("Whisper model: Training CROSS-ATTENTION layers in decoder ONLY.")
                for name, param in self.whisper_model.named_parameters():
                    param.requires_grad = False # Freeze all by default
                
                if hasattr(self.whisper_model, 'model') and hasattr(self.whisper_model.model, 'decoder'):
                    for layer in self.whisper_model.model.decoder.layers:
                        if hasattr(layer, 'encoder_attn'):
                            print(f"  Unfreezing cross-attention: {layer}.encoder_attn")
                            for param in layer.encoder_attn.parameters():
                                param.requires_grad = True
                        # Also typically unfreeze the layer norm following cross-attention
                        if hasattr(layer, 'encoder_attn_layer_norm'):
                            print(f"  Unfreezing layer_norm for cross-attention: {layer}.encoder_attn_layer_norm")
                            for param in layer.encoder_attn_layer_norm.parameters():
                                param.requires_grad = True
                # The proj_out head should also be trainable if we expect the model to learn predictions
                if hasattr(self.whisper_model, 'proj_out'):
                    print("  Unfreezing proj_out for Whisper.")
                    for param in self.whisper_model.proj_out.parameters():
                        param.requires_grad = True
            
            elif whisper_freezing_strategy == 'none':
                print("Whisper model: ALL parameters TRAINABLE.")
                for param in self.whisper_model.parameters():
                    param.requires_grad = True
            else:
                print(f"Warning: Unknown whisper_freezing_strategy '{whisper_freezing_strategy}'. Whisper parameters will not be specifically frozen by this strategy.")

    def _get_downsampled_lens(self, original_lens):
        if self.downsample_factor <= 1:
            return original_lens
        # Use ceil to handle non-divisible lengths correctly
        return torch.ceil(original_lens.float() / self.downsample_factor).long()

    def _generate_padding_mask(self, lengths, max_len):
        mask = torch.arange(max_len, device=lengths.device)[None, :] >= lengths[:, None]
        return mask

    def _encoder_forward(self, src_ecog, src_ecog_lens, days=None):
        """
        Processes the ECoG signal through the encoder.
        Returns the final encoder output for decoders, and a designated
        intermediate encoder output for auxiliary prediction heads.
        """
        # src_ecog: [B, T_ecog, N_channels]
        # src_ecog_lens: [B]
        
        src_proj = self.input_projection(src_ecog)  # [B, T_ecog, d_model]
        
        if self.day_transform and days is not None:
            src_proj = self.day_transform(src_proj, days)

        src_emb = self.pos_encoder(src_proj * math.sqrt(self.hparams.d_model)) # Scale before pos encoding
        
        max_encoder_len = src_emb.size(1)
        src_padding_mask = self._generate_padding_mask(src_ecog_lens, max_encoder_len) # [B, T_ecog]

        # --- Selectable Encoder Forward Pass ---
        if self.hparams.get('encoder_implementation', 'custom_class') == 'custom_class':
            # --- High-performance path using CustomTransformerEncoder or standard nn.TransformerEncoder ---
            if isinstance(self.encoder, CustomTransformerEncoder):
                memory, intermediate_output_for_aux = self.encoder(
                    src_emb, src_key_padding_mask=src_padding_mask
                )
            else: # It's a standard nn.TransformerEncoder
                memory = self.encoder(src_emb, src_key_padding_mask=src_padding_mask)
                intermediate_output_for_aux = memory # Fallback to final output
        else:
            # --- Explicit looping path for debugging/comparison ---
            aux_head_input_layer_idx = self.hparams.get('aux_head_input_layer_idx', -1)
            
            encoder_output_current = src_emb
            intermediate_output_for_aux = None
            memory = None

            num_encoder_layers = len(self.encoder.layers)
            if not (0 <= aux_head_input_layer_idx < num_encoder_layers):
                memory = self.encoder(encoder_output_current, src_key_padding_mask=src_padding_mask)
                intermediate_output_for_aux = memory
            else:
                for i, layer in enumerate(self.encoder.layers):
                    encoder_output_current = layer(encoder_output_current, src_key_padding_mask=src_padding_mask)
                    if i == aux_head_input_layer_idx:
                        intermediate_output_for_aux = encoder_output_current
                memory = encoder_output_current

        # --- Zero out encoder outputs corresponding to padding ---
        if src_padding_mask is not None:
            memory = memory.masked_fill(src_padding_mask.unsqueeze(-1), 0.0)
            if intermediate_output_for_aux is not None:
                intermediate_output_for_aux = intermediate_output_for_aux.masked_fill(src_padding_mask.unsqueeze(-1), 0.0)

        return memory, intermediate_output_for_aux, src_padding_mask

    def _decoder_forward(self, tgt_phonemes_input, memory, memory_key_padding_mask, tgt_phoneme_lens,
                         len_condition_vec=None, wc_condition_vec=None): # Added new params
        # tgt_phonemes_input: [B, T_phoneme_in] (e.g., SOS + phonemes)
        # memory: [B, T_ecog, d_model] (potentially augmented)
        # memory_key_padding_mask: [B, T_ecog]
        # tgt_phoneme_lens: [B] (length of tgt_phonemes_input)
        # len_condition_vec: [B, global_feature_emb_dim] or None
        # wc_condition_vec: [B, global_feature_emb_dim] or None

        tgt_emb = self.phoneme_embedding(tgt_phonemes_input) # [B, T_phoneme_in, phoneme_emb_dim]

        # Augment decoder input embeddings if global features are provided and augmentation is enabled
        if self.hparams.get('augment_decoder_with_phoneme_len', False) or \
           self.hparams.get('augment_decoder_with_word_count', False):
            # Only call augmentation if at least one global feature type is enabled for augmentation
            # and at least one relevant condition vector is actually provided.
            if len_condition_vec is not None or wc_condition_vec is not None:
                 tgt_emb = self._augment_decoder_input_embeddings(tgt_emb, len_condition_vec, wc_condition_vec)
        
        tgt_emb = self.pos_decoder(tgt_emb * math.sqrt(self.hparams.d_model))

        max_tgt_len = tgt_phonemes_input.size(1)
        tgt_padding_mask = self._generate_padding_mask(tgt_phoneme_lens, max_tgt_len) # [B, T_phoneme_in]
        
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(max_tgt_len, device=tgt_emb.device).bool() # [T_phoneme_in, T_phoneme_in]

        decoder_output = self.decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask
        ) # [B, T_phoneme_in, d_model]
        
        logits = self.output_projection(decoder_output) # [B, T_phoneme_in, vocab_size]
        return logits

    def _calculate_auxiliary_losses(self, encoder_outputs, batch, downsampled_ecog_lens):
        # encoder_outputs: [B, T_enc, d_model]
        # downsampled_ecog_lens: [B] (lengths of the downsampled sequence)
        
        aux_losses_dict = {}
        aux_predictions_dict = {} # To store raw predictions from aux heads
        aux_targets_dict = {}
        total_aux_loss = torch.tensor(0.0, device=encoder_outputs.device)
        
        # Per-timestep auxiliary tasks
        if 'mfcc' in self.aux_heads:
            mfcc_preds_raw = self.aux_heads['mfcc'](encoder_outputs) # [B, T_enc, num_mfcc * F]
            aux_predictions_dict['mfcc'] = mfcc_preds_raw # Store raw predictions
            
            mfcc_targets = batch['mfcc']
            # Slice the target tensor to use the number of features specified in hparams
            mfcc_targets = mfcc_targets[:, :, :self.hparams.num_mfcc_features]
            
            if self.downsample_factor > 1:
                # Unfold the target to match the downsampled prediction space
                T_target = mfcc_targets.size(1)
                num_frames = T_target // self.downsample_factor
                mfcc_targets = mfcc_targets[:, :num_frames * self.downsample_factor, :]
                mfcc_targets = mfcc_targets.contiguous().view(mfcc_targets.size(0), num_frames, -1)

            max_len_mfcc = min(mfcc_preds_raw.size(1), mfcc_targets.size(1))
            mfcc_targets_for_loss = mfcc_targets[:, :max_len_mfcc, :]
            mfcc_preds_for_loss = mfcc_preds_raw[:, :max_len_mfcc, :]
            
            mask_mfcc = (~self._generate_padding_mask(downsampled_ecog_lens, max_len_mfcc)).unsqueeze(-1).float()
            
            mfcc_loss = F.mse_loss(mfcc_preds_for_loss * mask_mfcc, mfcc_targets_for_loss * mask_mfcc, reduction='sum') / mask_mfcc.sum().clamp(min=1e-9)
            aux_losses_dict['mfcc_loss'] = mfcc_loss
            total_aux_loss += self.hparams.get('mfcc_loss_weight', 1.0) * mfcc_loss
            aux_targets_dict['mfcc'] = mfcc_targets.detach() # Store downsampled target

        if 'envelope' in self.aux_heads:
            env_preds_raw = self.aux_heads['envelope'](encoder_outputs) # [B, T_enc, 1 * F]
            aux_predictions_dict['envelope'] = env_preds_raw # Store raw predictions

            env_targets = batch['audioEnvelope']
            if env_targets.ndim == 2: env_targets = env_targets.unsqueeze(-1) # Ensure [B, T, 1]

            if self.downsample_factor > 1:
                T_target = env_targets.size(1)
                num_frames = T_target // self.downsample_factor
                env_targets = env_targets[:, :num_frames * self.downsample_factor, :]
                env_targets = env_targets.contiguous().view(env_targets.size(0), num_frames, -1)

            max_len_env = min(env_preds_raw.size(1), env_targets.size(1))
            env_targets_for_loss = env_targets[:, :max_len_env, :]
            env_preds_for_loss = env_preds_raw[:, :max_len_env, :]

            mask_env = (~self._generate_padding_mask(downsampled_ecog_lens, max_len_env)).unsqueeze(-1).float()
            env_loss = F.mse_loss(env_preds_for_loss * mask_env, env_targets_for_loss * mask_env, reduction='sum') / mask_env.sum().clamp(min=1e-9)
            aux_losses_dict['envelope_loss'] = env_loss
            total_aux_loss += self.hparams.get('envelope_loss_weight', 1.0) * env_loss
            
            aux_targets_dict['envelope'] = env_targets.detach() # Store downsampled target

        # Sequence-level auxiliary tasks
        pooled_encoder_output = None
        needs_pooling = ('phoneme_len' in self.aux_heads) or ('word_count' in self.aux_heads)

        if needs_pooling:
            # Here, we use the downsampled lengths for correct pooling
            pool_mask = (~self._generate_padding_mask(downsampled_ecog_lens, encoder_outputs.size(1))).unsqueeze(-1).float()
            pooled_encoder_output = (encoder_outputs * pool_mask).sum(dim=1) / pool_mask.sum(dim=1).clamp(min=1e-9)

        if 'phoneme_len' in self.aux_heads and pooled_encoder_output is not None:
            scaling_factor = self.hparams.get('phoneme_len_scaling_factor', 30.0) # Use hparam or default
            phoneme_len_preds_raw = self.aux_heads['phoneme_len'](pooled_encoder_output) # [B, 1]
            aux_predictions_dict['phoneme_len'] = phoneme_len_preds_raw * scaling_factor # Store unnormalized prediction for potential use in SG guidance
            
            phoneme_len_targets = batch['phoneSeqLen'].float().unsqueeze(-1) / scaling_factor
            phoneme_len_loss = F.mse_loss(phoneme_len_preds_raw, phoneme_len_targets) # Loss on scaled preds
            
            aux_losses_dict['phoneme_len_loss'] = phoneme_len_loss
            total_aux_loss += self.hparams.get('phoneme_len_loss_weight', 1.0) * phoneme_len_loss
            aux_targets_dict['phoneme_len'] = batch['phoneSeqLen'].float().unsqueeze(-1).detach()


        if 'word_count' in self.aux_heads:
            scaling_factor = self.hparams.get('word_count_scaling_factor', 10.0) # Use hparam or default
            word_count_preds_raw = self.aux_heads['word_count'](pooled_encoder_output) # [B, 1]
            aux_predictions_dict['word_count'] = word_count_preds_raw * scaling_factor # Store unnormalized prediction
            
            word_count_targets = batch['numWords'].float().unsqueeze(-1) / scaling_factor
            word_count_loss = F.mse_loss(word_count_preds_raw, word_count_targets) # Loss on scaled preds

            aux_losses_dict['word_count_loss'] = word_count_loss
            total_aux_loss += self.hparams.get('word_count_loss_weight', 1.0) * word_count_loss
            aux_targets_dict['word_count'] = batch['numWords'].float().unsqueeze(-1).detach()
                
        return total_aux_loss, aux_losses_dict, aux_predictions_dict, aux_targets_dict

    def _prepare_decoder_inputs_targets(self, phone_seq, phone_seq_lens):
        """
        Build teacher-forcing decoder inputs and targets from padded phoneme sequences.

        For a sample with true length L and phonemes p1..pL:

            input : [SOS, p1, ..., pL, PAD, ...]
            target: [p1, ..., pL, EOS, PAD, ...]

        Both tensors have width ``phone_seq.size(1) + 1``. Every position past a
        sample's true content is PAD_IDX, whatever ``phone_seq`` holds there.

        Args:
            phone_seq: [B, T] phoneme ids, padded.
            phone_seq_lens: [B] true lengths, excluding padding.

        Returns:
            A tuple ``(decoder_input, decoder_effective_lengths, decoder_target)``.
            decoder_input and decoder_target are [B, T + 1] long tensors on
            ``self.device``. decoder_effective_lengths is ``phone_seq_lens + 1``.
        """
        batch_size = phone_seq.size(0)
        # Built with broadcast masks rather than a per-sample Python loop. The loop
        # needed one .item() host sync for every sample in the batch.
        seq = phone_seq.to(device=self.device, dtype=torch.long)               # [B, T]
        lens = phone_seq_lens.to(device=self.device).long().unsqueeze(1)       # [B, 1]
        positions = torch.arange(seq.size(1) + 1, device=self.device).unsqueeze(0)  # [1, T+1]

        sos_column = torch.full((batch_size, 1), SOS_IDX, dtype=torch.long, device=self.device)
        pad_column = torch.full((batch_size, 1), PAD_IDX, dtype=torch.long, device=self.device)

        # Input: sequence shifted right behind SOS. Keep SOS plus the first L
        # phonemes (positions 0..L); everything later becomes PAD.
        final_decoder_input = torch.cat([sos_column, seq], dim=1).masked_fill(positions > lens, PAD_IDX)

        # Target: the first L phonemes, then EOS at position L, then PAD.
        final_decoder_target = torch.cat([seq, pad_column], dim=1).masked_fill(positions >= lens, PAD_IDX)
        final_decoder_target = final_decoder_target.masked_fill(positions == lens, EOS_IDX)

        decoder_effective_lengths = phone_seq_lens + 1
        return final_decoder_input, decoder_effective_lengths, final_decoder_target


    def training_step(self, batch, batch_idx):
        """
        Run one training step and return the scalar loss to optimise.

        Which terms enter the loss depends on ``hparams.training_stage``:

        * 'encoder_only': the weighted auxiliary loss, when auxiliary heads exist.
        * 'joint_teacher_forcing': teacher-forced phoneme cross-entropy, plus the
          BART/Whisper text losses when those decoders are active, plus the
          auxiliary loss when ``train_aux_in_joint_tf`` is set.
        * 'joint_sequential_generation': phoneme cross-entropy on generated
          logits, plus the text losses, plus the auxiliary loss when
          ``train_aux_in_joint_sg`` is set.
        * 'secondary_only': the text losses, plus the auxiliary loss when
          ``train_aux_in_secondary_only`` is set.

        The auxiliary loss is added at most once per step.

        Args:
            batch: Dict with 'neuralData', 'neuralDataLen', 'phoneSeq' and
                'phoneSeqLen'. Optional keys are 'day' and 'sentence', plus any
                targets read by the auxiliary heads.
            batch_idx: Index of the batch; used to print warnings once per epoch.
        """
        src_ecog = batch['neuralData']
        src_ecog_lens = batch['neuralDataLen']
        days = batch.get('day', None)
        phone_seq_targets_original = batch['phoneSeq']
        phone_seq_lens_original = batch['phoneSeqLen']
        target_text_sentences = batch.get('sentence', None)
        batch_size = src_ecog.size(0)
        training_stage = self.hparams.training_stage

        # --- Augmentation (during training only) ---
        if self.training:
            # Augment a copy. The additive noise below and _apply_spec_augment both
            # write in place, and src_ecog is otherwise the same tensor object
            # stored in `batch`.
            src_ecog = src_ecog.clone()
            if self.white_noise_sd > 0:
                src_ecog += torch.randn(src_ecog.shape, device=self.device) * self.white_noise_sd
            if self.constant_offset_sd > 0:
                # One offset per (sample, channel), shared across all time steps.
                src_ecog += torch.randn([src_ecog.shape[0], 1, src_ecog.shape[2]], device=self.device) * self.constant_offset_sd
            src_ecog = self._apply_spec_augment(src_ecog, src_ecog_lens)

        total_loss = torch.tensor(0.0, device=self.device)

        # --- Encoder (lengths measured in downsampled encoder frames) ---
        downsampled_src_ecog_lens = self._get_downsampled_lens(src_ecog_lens)
        raw_ecog_encoder_outputs, intermediate_for_aux, memory_key_padding_mask = \
            self._encoder_forward(src_ecog, downsampled_src_ecog_lens, days)

        # --- Auxiliary heads (computed on the intermediate encoder output) ---
        aux_loss_val, current_aux_losses_dict, aux_predictions_dict, _ = \
            self._calculate_auxiliary_losses(intermediate_for_aux, batch, downsampled_src_ecog_lens)
        has_aux_heads = hasattr(self, 'aux_heads') and len(self.aux_heads) > 0

        # --- Encoder representation for the phoneme decoder (optionally MFCC-augmented) ---
        final_encoder_representation_for_phoneme_decoder = raw_ecog_encoder_outputs
        predicted_mfccs_for_aug = aux_predictions_dict.get('mfcc') if aux_predictions_dict is not None else None
        if self.hparams.get('augment_encoder_with_mfcc', False) and predicted_mfccs_for_aug is not None:
            # Detached: gradients from this path stop at the MFCC prediction.
            final_encoder_representation_for_phoneme_decoder = self._augment_encoder_outputs(
                raw_ecog_encoder_outputs,
                predicted_mfccs_for_aug.detach(),
            )

        # --- Global conditioning vectors for the phoneme decoder (optional) ---
        len_condition_vec = None
        wc_condition_vec = None
        if aux_predictions_dict is not None:
            predicted_len = aux_predictions_dict.get('phoneme_len')
            if self.hparams.get('augment_decoder_with_phoneme_len', False) and predicted_len is not None:
                len_condition_vec = self.len_embedder(predicted_len.detach())
            predicted_wc = aux_predictions_dict.get('word_count')
            if self.hparams.get('augment_decoder_with_word_count', False) and predicted_wc is not None:
                wc_condition_vec = self.word_count_embedder(predicted_wc.detach())

        # --- Phoneme decoder ---
        if training_stage == 'joint_teacher_forcing':
            decoder_input, decoder_input_lens, decoder_target = \
                self._prepare_decoder_inputs_targets(phone_seq_targets_original, phone_seq_lens_original)

            logits = self._decoder_forward(
                decoder_input,
                final_encoder_representation_for_phoneme_decoder,
                memory_key_padding_mask,
                decoder_input_lens,
                len_condition_vec=len_condition_vec,
                wc_condition_vec=wc_condition_vec,
            )
            main_loss_phoneme_unweighted = F.cross_entropy(
                logits.reshape(-1, self.hparams.vocab_size),
                decoder_target.reshape(-1),
                ignore_index=PAD_IDX,
            )
            phoneme_loss_weight = self.hparams.get('phoneme_main_loss_weight', 1.0)
            main_loss_phoneme_weighted = main_loss_phoneme_unweighted * phoneme_loss_weight

            self.log('train_main_phoneme_loss_tf_unweighted', main_loss_phoneme_unweighted, prog_bar=False, sync_dist=True, batch_size=batch_size)
            if phoneme_loss_weight != 1.0:
                self.log('train_main_phoneme_loss_tf_weighted', main_loss_phoneme_weighted, prog_bar=True, sync_dist=True, batch_size=batch_size)
            else:
                self.log('train_main_phoneme_loss_tf', main_loss_phoneme_unweighted, prog_bar=True, sync_dist=True, batch_size=batch_size)
            total_loss = total_loss + main_loss_phoneme_weighted

            if self.hparams.get('train_aux_in_joint_tf', False) and has_aux_heads:
                # Logging only. The aux loss enters total_loss exactly once, in the
                # stage-dependent block at the end of this method.
                self.log('train_joint_tf_aux_loss', aux_loss_val, sync_dist=True, on_epoch=True, batch_size=batch_size)
                for name, val in current_aux_losses_dict.items():
                    self.log(f'train_joint_tf_{name}', val, sync_dist=True, on_epoch=True, batch_size=batch_size)

        elif training_stage == 'joint_sequential_generation':
            _, _, decoder_target_sg = \
                self._prepare_decoder_inputs_targets(phone_seq_targets_original, phone_seq_lens_original)

            guidance_lengths_for_generation = None
            use_predicted_len = self.hparams.get('use_predicted_len_for_sg', False)
            if use_predicted_len and aux_predictions_dict is not None and 'phoneme_len' in aux_predictions_dict:
                predicted_lengths_float = aux_predictions_dict['phoneme_len'].squeeze(-1)
                self.log('predicted_guidance_len_mean', predicted_lengths_float.mean(), sync_dist=True, batch_size=batch_size)
                predicted_lengths_float = predicted_lengths_float + self.hparams.get('sg_gen_buffer', 0.0)
                guidance_lengths_for_generation = torch.round(predicted_lengths_float).long().clamp(min=1).detach()
            elif use_predicted_len:
                print("Warning: 'use_predicted_len_for_sg' is True, but 'phoneme_len' prediction not found.")

            max_gen_len_config = self.hparams.get('max_gen_len_train_sg', decoder_target_sg.size(1))
            use_gumbel = self.hparams.get('use_gumbel_for_sg_train', False)
            gumbel_temp = self.hparams.get('gumbel_tau_sg', 1.0)
            if use_gumbel:
                self.log('gumbel_tau_sg', gumbel_temp, on_step=True, on_epoch=False, sync_dist=True, batch_size=batch_size)

            generated_logits, _ = self._perform_sequential_generation(
                final_encoder_representation_for_phoneme_decoder,
                memory_key_padding_mask,
                max_len=max_gen_len_config,
                guidance_lengths=guidance_lengths_for_generation,
                use_gumbel_softmax=use_gumbel,
                gumbel_tau=gumbel_temp,
                len_condition_vec=len_condition_vec,
                wc_condition_vec=wc_condition_vec,
            )
            # Generated and reference lengths can differ; the loss uses their overlap.
            len_for_loss = min(generated_logits.size(1), decoder_target_sg.size(1))
            if len_for_loss > 0:
                main_loss_sg_phoneme_unweighted = F.cross_entropy(
                    generated_logits[:, :len_for_loss, :].reshape(-1, self.hparams.vocab_size),
                    decoder_target_sg[:, :len_for_loss].reshape(-1),
                    ignore_index=PAD_IDX,
                )
                phoneme_loss_weight = self.hparams.get('phoneme_main_loss_weight', 1.0)
                main_loss_sg_phoneme_weighted = main_loss_sg_phoneme_unweighted * phoneme_loss_weight

                self.log('train_main_phoneme_loss_sg_unweighted', main_loss_sg_phoneme_unweighted, prog_bar=False, sync_dist=True, batch_size=batch_size)
                if phoneme_loss_weight != 1.0:
                    self.log('train_main_phoneme_loss_sg_weighted', main_loss_sg_phoneme_weighted, prog_bar=True, sync_dist=True, batch_size=batch_size)
                else:
                    self.log('train_main_phoneme_loss_sg', main_loss_sg_phoneme_unweighted, prog_bar=True, sync_dist=True, batch_size=batch_size)
                total_loss = total_loss + main_loss_sg_phoneme_weighted
            else:
                self.log('train_main_phoneme_loss_sg', torch.tensor(0.0, device=self.device), prog_bar=True, sync_dist=True, batch_size=batch_size)

        # --- Text decoders (BART and/or Whisper) on the raw ECoG encoder output ---
        text_decoder_stages = ['joint_teacher_forcing', 'joint_sequential_generation', 'secondary_only']

        if self.bart_text_decoder_active and self.bart_model is not None and self.bart_tokenizer is not None and \
           training_stage in text_decoder_stages:
            if target_text_sentences is not None and all(isinstance(s, str) for s in target_text_sentences):
                bart_labels = self.bart_tokenizer(
                    list(target_text_sentences),
                    return_tensors="pt", padding=True, truncation=True,
                    max_length=self.hparams.get('bart_max_target_len', 128),
                ).input_ids.to(self.device)

                # Attention mask over our ECoG encoder frames: 1 = valid, 0 = padding.
                bart_encoder_padding_mask = (~memory_key_padding_mask).long() if memory_key_padding_mask is not None else None
                projected_raw_encoder_outputs_for_bart = self.ecog_to_bart_hidden_projection(raw_ecog_encoder_outputs)

                bart_outputs = self.bart_model(
                    encoder_outputs=(projected_raw_encoder_outputs_for_bart,),
                    labels=bart_labels,
                    return_dict=True,
                    attention_mask=bart_encoder_padding_mask,
                )
                bart_loss = bart_outputs.loss

                bart_loss_weight = self.hparams.get('bart_text_loss_weight', 1.0)
                weighted_bart_loss = bart_loss * bart_loss_weight
                total_loss = total_loss + weighted_bart_loss
                self.log('train_bart_text_loss', bart_loss.detach(), prog_bar=False, sync_dist=True, batch_size=batch_size)
                if bart_loss_weight != 1.0:
                    self.log('train_bart_text_loss_weighted', weighted_bart_loss.detach(), prog_bar=False, sync_dist=True, batch_size=batch_size)
            elif batch_idx == 0 and rank_zero_only.rank == 0:
                print(f"Warning: 'sentence' key not found, is None, or not list of str in batch. Type: {type(target_text_sentences)}. Skipping BART text decoder training for this batch.")

        if self.whisper_text_decoder_active and self.whisper_model is not None and self.whisper_processor is not None and \
           training_stage in text_decoder_stages:
            if target_text_sentences is not None and all(isinstance(s, str) for s in target_text_sentences):
                whisper_labels = self.whisper_processor(
                    text=list(target_text_sentences), return_tensors="pt", padding=True, truncation=True,
                    max_length=self.hparams.get('whisper_max_target_len', 128),
                ).input_ids.to(self.device)

                projected_raw_encoder_outputs_for_whisper = self.ecog_to_whisper_hidden_projection(raw_ecog_encoder_outputs)

                # No decoder_input_ids are passed; only `labels` is given to the model.
                whisper_outputs = self.whisper_model(
                    encoder_outputs=BaseModelOutput(last_hidden_state=projected_raw_encoder_outputs_for_whisper),
                    labels=whisper_labels,
                    return_dict=True,
                )
                whisper_loss = whisper_outputs.loss

                whisper_loss_weight = self.hparams.get('whisper_text_loss_weight', 1.0)
                weighted_whisper_loss = whisper_loss * whisper_loss_weight
                total_loss = total_loss + weighted_whisper_loss
                self.log('train_whisper_text_loss', whisper_loss.detach(), prog_bar=False, sync_dist=True, batch_size=batch_size)
                if whisper_loss_weight != 1.0:
                    self.log('train_whisper_text_loss_weighted', weighted_whisper_loss.detach(), prog_bar=False, sync_dist=True, batch_size=batch_size)
            elif batch_idx == 0 and rank_zero_only.rank == 0:
                print(f"Warning: 'sentence' key not found or not list of str. Skipping Whisper text decoder training for this batch.")

        # --- Auxiliary loss: the single place it is added to total_loss ---
        if training_stage == 'encoder_only' and has_aux_heads:
            total_loss = aux_loss_val
            self.log('train_total_aux_loss_enc_only', total_loss, prog_bar=True, sync_dist=True, batch_size=batch_size)
            for name, val in current_aux_losses_dict.items():
                self.log(f'train_{name}_enc_only', val, sync_dist=True, batch_size=batch_size)

        aux_flag_for_stage = {
            'joint_teacher_forcing': 'train_aux_in_joint_tf',
            'joint_sequential_generation': 'train_aux_in_joint_sg',
            'secondary_only': 'train_aux_in_secondary_only',
        }.get(training_stage)
        if aux_flag_for_stage is not None and self.hparams.get(aux_flag_for_stage, False):
            total_loss = total_loss + aux_loss_val

        self.log('train_total_loss', total_loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=batch_size)
        return total_loss
    
    def _apply_spec_augment(self, x, x_lens):
        """
        Applies SpecAugment-style masking to a batch of sequences.
        This version applies a different mask to each sample in the batch.
        """
        batch_size, _, num_channels = x.shape
        num_physical_electrodes = num_channels // 2

        for i in range(batch_size):
            # --- Time Masking ---
            if self.time_masking_prob > 0 and self.time_mask_max_len > 0 and self.time_mask_max_proportion > 0:
                if torch.rand(1) < self.time_masking_prob:
                    seq_len = x_lens[i].item()
                    
                    # Determine max mask length for this specific sample
                    prop_max_len = int(seq_len * self.time_mask_max_proportion)
                    effective_max_len = min(self.time_mask_max_len, prop_max_len)
                    
                    if effective_max_len > 1:
                        # Choose a random mask length
                        t = torch.randint(1, effective_max_len + 1, (1,)).item()
                        # Choose a random start position within the true sequence length
                        t0 = torch.randint(0, seq_len - t + 1, (1,)).item()
                        x[i, t0:t0+t, :] = 0.0

            # --- Channel Masking ---
            if self.channel_masking_prob > 0 and self.channel_mask_max_electrodes > 0:
                if torch.rand(1) < self.channel_masking_prob:
                    # Choose a random number of electrodes to mask
                    c = torch.randint(1, self.channel_mask_max_electrodes + 1, (1,)).item()
                    
                    # Choose `c` random electrode indices to mask
                    electrodes_to_mask = torch.randperm(num_physical_electrodes)[:c]
                    
                    # Create the full list of channel indices (spike counts + band power)
                    spike_count_indices = electrodes_to_mask
                    band_power_indices = electrodes_to_mask + num_physical_electrodes
                    all_channels_to_mask = torch.cat([spike_count_indices, band_power_indices])

                    x[i, :, all_channels_to_mask] = 0.0
                    
        return x

    def _perform_sequential_generation(self, memory, memory_key_padding_mask, max_len,
                                        guidance_lengths=None,
                                        use_gumbel_softmax: bool = False,
                                        gumbel_tau: float = 1.0,
                                        len_condition_vec=None, # New
                                        wc_condition_vec=None): # New
        """
        Performs greedy (or Gumbel-sampled) autoregressive generation with the decoder.

        At every step the whole prefix generated so far is re-embedded and re-decoded
        (no incremental cache); only the last position's logits pick the next token.

        Args:
            memory: encoder output used as cross-attention memory, potentially
                augmented (e.g. by `_augment_encoder_outputs`). Dim 0 is the batch.
            memory_key_padding_mask: key padding mask for `memory`, passed straight
                to `self.decoder`.
            max_len: total output length including the leading SOS token; at most
                `max_len - 1` tokens are generated.
            guidance_lengths: optional per-sample target lengths; unfinished
                sequences are forced to EOS once the count below reaches them.
            use_gumbel_softmax: if True, sample the next token with a hard
                Gumbel-softmax instead of taking the plain argmax.
            gumbel_tau: Gumbel-softmax temperature.
            len_condition_vec / wc_condition_vec: optional global conditioning
                vectors fed to `_augment_decoder_input_embeddings` when the
                corresponding `augment_decoder_with_*` hparam is enabled.

        Returns:
            (full_logits, generated_seq):
              full_logits: [B, steps, vocab_size] logits, one per generation step
                actually run (steps < max_len - 1 if every sequence finished early).
              generated_seq: [B, max(max_len, 1)] token ids starting with SOS_IDX;
                positions after a sequence finished, and trailing padding, are PAD_IDX.
        """
        B = memory.size(0)
        device = memory.device

        # Initialize generated sequence with SOS token
        generated_seq = torch.full((B, 1), SOS_IDX, dtype=torch.long, device=device)
        
        # Keep track of sequences that have generated EOS
        batch_sequences_finished = torch.zeros(B, dtype=torch.bool, device=device)

        # Store all logits for loss calculation if needed (e.g., for SG training)
        all_logits_list = []

        for step in range(max_len - 1): # Max_len includes SOS, so generate max_len-1 more tokens
            if batch_sequences_finished.all():
                break

            # 1. Embed the current full sequence
            current_tgt_emb = self.phoneme_embedding(generated_seq) # [B, CurrentLen, phoneme_emb_dim]

            # 2. Augment embeddings if enabled and condition vectors are provided
            if self.hparams.get('augment_decoder_with_phoneme_len', False) or \
               self.hparams.get('augment_decoder_with_word_count', False):
                if len_condition_vec is not None or wc_condition_vec is not None:
                    current_tgt_emb = self._augment_decoder_input_embeddings(
                        current_tgt_emb,
                        len_condition_vec,
                        wc_condition_vec
                    )
            
            # 3. Apply scaling and positional encoding
            current_tgt_emb = self.pos_decoder(current_tgt_emb * math.sqrt(self.hparams.d_model))

            # 4. Create target mask for decoder (causal: position i sees positions <= i)
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(generated_seq.size(1), device=device)

            # 5. Decoder forward pass
            decoder_output = self.decoder(
                current_tgt_emb, # Full embedded and augmented sequence so far
                memory,          # Encoder memory (potentially augmented)
                tgt_mask=tgt_mask,
                memory_key_padding_mask=memory_key_padding_mask
            ) # [B, CurrentLen, d_model]

            # 6. Get logits for the *next* token prediction (from the last output position)
            current_step_logits = self.output_projection(decoder_output[:, -1, :]) # [B, vocab_size]
            all_logits_list.append(current_step_logits.unsqueeze(1)) # Store [B, 1, vocab_size]

            # 7. Select next token
            # Note: in the Gumbel branch the sampled index is extracted with argmax, so
            # `next_token` is an integer tensor and no gradient flows through token choice.
            if use_gumbel_softmax:
                next_token_probs = F.gumbel_softmax(current_step_logits, tau=gumbel_tau, hard=True, dim=-1)
                next_token = torch.argmax(next_token_probs, dim=-1) # [B]
            else:
                next_token = torch.argmax(current_step_logits, dim=-1) # [B]

            # Mask out predictions for sequences that have already finished
            next_token = next_token.masked_fill(batch_sequences_finished, PAD_IDX)
            
            # Update generated sequence
            generated_seq = torch.cat([generated_seq, next_token.unsqueeze(1)], dim=1)

            # Update finished status
            batch_sequences_finished = batch_sequences_finished | (next_token == EOS_IDX)

            # Optional: Stop generation if guidance_lengths are met (after EOS check)
            if guidance_lengths is not None:
                # NOTE(review): `!= SOS_IDX` already excludes the leading SOS, so the
                # extra `- 1` makes this count one less than the number of tokens
                # generated so far (assuming the model never emits SOS_IDX). Net effect:
                # EOS is forced once guidance_lengths + 1 tokens exist, overwriting the
                # last one, i.e. the output keeps `guidance_lengths` tokens followed by
                # EOS. The "-1 for SOS" comment below is misleading; confirm intent.
                # Number of actual phonemes generated (excluding SOS, including potential EOS)
                current_phoneme_counts = (generated_seq != SOS_IDX).sum(dim=1) -1 # -1 for SOS
                # if generated includes EOS, it counts towards length.
                # if guidance_lengths is for phonemes *before* EOS, adjust logic.
                # Assuming guidance_lengths is the target number of phonemes (EOS can be one of them or come after)
                reached_guidance_len = current_phoneme_counts >= guidance_lengths
                
                # Force EOS for sequences that reached guidance length but haven't naturally produced EOS
                # (only the token id is overwritten; the stored logits for that step are unchanged).
                force_eos_mask = reached_guidance_len & (~batch_sequences_finished)
                if force_eos_mask.any():
                    generated_seq[force_eos_mask, -1] = EOS_IDX # Overwrite last token with EOS
                    batch_sequences_finished = batch_sequences_finished | force_eos_mask


        # Concatenate all logits along the sequence dimension
        if all_logits_list:
            full_logits = torch.cat(all_logits_list, dim=1) # [B, GenLen-1, VocabSize]
        else: # No actual generation steps occurred (e.g., max_len <= 1)
            full_logits = torch.empty((B, 0, self.hparams.vocab_size), device=device, dtype=memory.dtype)

        # Pad generated_seq to max_len if it's shorter (e.g., all sequences finished early)
        # This is mostly for consistency in return shape, loss calculation should use actual lengths.
        if generated_seq.size(1) < max_len:
            padding_size = max_len - generated_seq.size(1)
            padding = torch.full((B, padding_size), PAD_IDX, dtype=torch.long, device=device)
            generated_seq = torch.cat([generated_seq, padding], dim=1)
        
        return full_logits, generated_seq


    # --- Helper method for Encoder Augmentation ---
    def _augment_encoder_outputs(self, encoder_outputs, predicted_mfccs):
        # predicted_mfccs shape: [B, T_enc, mfcc_dim]
        # encoder_outputs shape: [B, T_enc, d_model]
        
        # This layer (self.encoder_feature_fusion_projection) is always initialized.
        # We rely on the hparam check in the calling function (e.g., training_step)
        # to decide whether this augmentation should occur.

        # Ensure correct sequence lengths if MFCCs were downsampled differently
        T_enc_out = encoder_outputs.size(1)
        T_mfcc = predicted_mfccs.size(1)

        if T_enc_out != T_mfcc:
            # Simple alignment: truncate the longer one to match the shorter one's length.
            # More sophisticated alignment might be needed depending on the cause of mismatch.
            min_len = min(T_enc_out, T_mfcc)
            if T_enc_out > min_len:
                encoder_outputs = encoder_outputs[:, :min_len, :]
                # print(f"Warning: Truncating encoder_outputs from {T_enc_out} to {min_len} to match MFCCs for augmentation.")
            if T_mfcc > min_len:
                predicted_mfccs = predicted_mfccs[:, :min_len, :]
                # print(f"Warning: Truncating predicted_mfccs from {T_mfcc} to {min_len} to match encoder_outputs for augmentation.")
        
        # Detach MFCC predictions as they are conditioning signals here
        combined_features = torch.cat((encoder_outputs, predicted_mfccs.detach()), dim=-1)
        augmented_encoder_outputs = self.encoder_feature_fusion_projection(combined_features)
        return augmented_encoder_outputs

    # --- Helper method for Decoder Input Augmentation ---
    def _augment_decoder_input_embeddings(self, phoneme_embeddings, len_condition_vec, wc_condition_vec):
        # phoneme_embeddings shape: [B, T_dec, phoneme_emb_dim]
        # len_condition_vec shape: [B, global_feature_emb_dim] or None
        # wc_condition_vec shape: [B, global_feature_emb_dim] or None

        # These layers (len_embedder, word_count_embedder, decoder_input_fusion_projection)
        # are always initialized. The calling function checks hparams to decide if augmentation happens.

        active_conditions_to_cat = [phoneme_embeddings]
        
        # Check hparams again here to ensure we only use active augmentations
        if self.hparams.get('augment_decoder_with_phoneme_len', False) and len_condition_vec is not None:
            # len_condition_vec is already projected by self.len_embedder and detached in training_step
            active_conditions_to_cat.append(
                len_condition_vec.unsqueeze(1).repeat(1, phoneme_embeddings.size(1), 1)
            )
        
        if self.hparams.get('augment_decoder_with_word_count', False) and wc_condition_vec is not None:
            # wc_condition_vec is already projected by self.word_count_embedder and detached in training_step
            active_conditions_to_cat.append(
                wc_condition_vec.unsqueeze(1).repeat(1, phoneme_embeddings.size(1), 1)
            )

        if len(active_conditions_to_cat) == 1: # Only phoneme_embeddings, no global features were active/provided
            return phoneme_embeddings
        
        concatenated_embeddings = torch.cat(active_conditions_to_cat, dim=-1)
        
        # The self.decoder_input_fusion_projection was defined with max_possible_global_features.
        # If fewer features are concatenated, this projection layer might expect a different input size.
        # This needs to be handled: either the projection layer is dynamic, or we pad the input,
        # or we ensure the concatenation always results in the expected size (e.g. by zero-padding if a feature is off).

        # Let's adjust the __init__ for decoder_input_fusion_projection to be more robust,
        # or make this helper more careful.
        # For now, assuming decoder_input_fusion_projection's input dim matches the current concatenation.
        # This means __init__ needs to be more dynamic or this function needs to pad.

        # Revisit: The __init__ defines decoder_input_fusion_projection based on max_possible_global_features.
        # If only one global feature is active, the concatenated_embeddings dim will be smaller.
        # Solution: Create a zero tensor for inactive features to maintain consistent input dim for the projection layer.

        expected_total_dim = self.hparams.d_model
        processed_conditions_for_cat = [phoneme_embeddings]

        if self.hparams.get('augment_decoder_with_phoneme_len', False):
            if len_condition_vec is not None:
                processed_conditions_for_cat.append(len_condition_vec.unsqueeze(1).repeat(1, phoneme_embeddings.size(1), 1))
            else: # Augmentation enabled, but no vector provided (should not happen if logic is correct)
                  # Add zeros to maintain dimension for the fusion projection layer
                zeros_for_len = torch.zeros(
                    phoneme_embeddings.size(0), 
                    phoneme_embeddings.size(1), 
                    self.hparams.global_feature_emb_dim, 
                    device=phoneme_embeddings.device, 
                    dtype=phoneme_embeddings.dtype
                )
                processed_conditions_for_cat.append(zeros_for_len)
            expected_total_dim += self.hparams.global_feature_emb_dim
        
        if self.hparams.get('augment_decoder_with_word_count', False):
            if wc_condition_vec is not None:
                processed_conditions_for_cat.append(wc_condition_vec.unsqueeze(1).repeat(1, phoneme_embeddings.size(1), 1))
            else: # Augmentation enabled, but no vector provided
                zeros_for_wc = torch.zeros(
                    phoneme_embeddings.size(0), 
                    phoneme_embeddings.size(1), 
                    self.hparams.global_feature_emb_dim, 
                    device=phoneme_embeddings.device, 
                    dtype=phoneme_embeddings.dtype
                )
                processed_conditions_for_cat.append(zeros_for_wc)
            expected_total_dim += self.hparams.global_feature_emb_dim

        if len(processed_conditions_for_cat) == 1: # Still only phoneme_embeddings
             return phoneme_embeddings

        final_concatenated_embeddings = torch.cat(processed_conditions_for_cat, dim=-1)

        # Safety check for dimension, though the logic above should ensure it.
        if final_concatenated_embeddings.size(-1) != self.decoder_input_fusion_projection.in_features:
            # This indicates a mismatch between the dynamically concatenated features
            # and the expected input dimension of the fusion layer.
            # This should be resolved by the zero-padding logic above.
            # If this error still occurs, the __init__ for decoder_input_fusion_projection
            # or this concatenation logic needs further refinement.
            raise RuntimeError(f"Dimension mismatch for decoder_input_fusion_projection. "
                               f"Expected {self.decoder_input_fusion_projection.in_features}, "
                               f"got {final_concatenated_embeddings.size(-1)}")
        
        augmented_embeddings = self.decoder_input_fusion_projection(final_concatenated_embeddings)
        return augmented_embeddings

    def _get_token_str(self, token_id):
        return self.idx_to_token_map.get(token_id, str(token_id))
    
    def _ids_to_phoneme_list(self, id_sequence_tensor, strip_sos=True, strip_pad=True, keep_eos_str=True, stop_at_eos=True):
        """Converts a 1D tensor of token IDs to a list of phoneme strings."""
        if id_sequence_tensor.ndim == 0: # Handle single ID tensor
            id_sequence_tensor = id_sequence_tensor.unsqueeze(0)
        
        phoneme_list = []
        for token_id_val in id_sequence_tensor.tolist():
            if strip_sos and token_id_val == self.hparams.sos_idx:
                continue
            if token_id_val == self.hparams.eos_idx:
                if keep_eos_str:
                    phoneme_list.append(self._get_token_str(token_id_val))
                if stop_at_eos:
                    break # Stop at EOS
                else:
                    continue
            if strip_pad and token_id_val == self.hparams.pad_idx:
                continue # Skip PAD tokens if they appear before EOS (should ideally not happen for valid seqs)

            phoneme_list.append(self._get_token_str(token_id_val))
        return phoneme_list
    
    def _calculate_per(self, ref_ids_tensor, hyp_ids_tensor):
        """Calculates Phoneme Error Rate (PER).
        Assumes ref_ids_tensor and hyp_ids_tensor are 1D tensors of token IDs.
        SOS, PAD are stripped. EOS is NOT stripped before PER calculation.
        """
        # Convert to phoneme lists, stripping SOS, PAD, and not keeping EOS string for PER
        ref_phonemes = self._ids_to_phoneme_list(ref_ids_tensor, strip_sos=True, strip_pad=True, keep_eos_str=True)
        hyp_phonemes = self._ids_to_phoneme_list(hyp_ids_tensor, strip_sos=True, strip_pad=True, keep_eos_str=True)

        if not ref_phonemes: # Avoid division by zero if reference is empty after cleaning
            return 1.0 if hyp_phonemes else 0.0 # 100% error if hyp has content, 0% if both empty

        distance = editdistance.eval(ref_phonemes, hyp_phonemes)
        per = float(distance) / len(ref_phonemes)
        return per

    def _preprocess_text(self, text: str) -> str:
        """
        Remove punctuation, strip, and convert text to lowercase.
        """
        if not isinstance(text, str):
            return "" # Or handle as an error
        return text.translate(str.maketrans('', '', string.punctuation)).strip().lower()

    def _get_true_length(self, id_sequence_tensor):
        """Calculates the true length of a sequence, excluding SOS, PAD and EOS."""
        phoneme_list_for_length = self._ids_to_phoneme_list(
            id_sequence_tensor, 
            strip_sos=True, 
            strip_pad=True, # This will strip pads that might be part of the "true" length if not careful
                            # Let's refine: count non-PAD, non-SOS up to and including EOS
            keep_eos_str=False
        )

        return len(phoneme_list_for_length)
    
    def on_validation_epoch_start(self):
        """Reset per-epoch validation accumulators and open this epoch's CSV log.

        All ranks reset the metric lists, the aux-error sums/counts (for the
        'phoneme_len' / 'word_count' heads), and the global sample counter.

        Only rank 0 creates the CSV and writes its header. The file is
        ``<save_dir>/validation/validation_results_epoch_<epoch>.csv`` and
        replaces any existing file. ``<save_dir>`` is ``self.base_run_dir``
        when that attribute exists; otherwise it is
        ``experiments/<logger project>/<logger run name or run_<id>>``.

        CSV setup is best-effort. On failure the error is printed, the file is
        closed, and ``self.csv_file_handle`` / ``self.csv_writer`` are left as
        ``None``, so ``validation_step`` skips CSV writing.
        """
        # --- Setup for Local CSV Logging and metric aggregation ---
        self.all_pers_tf = []
        self.all_pers_sg = []
        self.all_wers_bart = []
        self.all_wers_whisper = []
        self.aux_error_sums = {k: 0.0 for k in self.aux_heads if k in ['phoneme_len', 'word_count']}
        self.aux_error_counts = {k: 0 for k in self.aux_heads if k in ['phoneme_len', 'word_count']}
        self.global_sample_idx = 0

        # Close a handle left over from a previous epoch (e.g. an interrupted
        # validation loop) before dropping the reference, so it is not leaked.
        # close() on an already-closed file is a no-op.
        stale_handle = getattr(self, 'csv_file_handle', None)
        if stale_handle is not None:
            try:
                stale_handle.close()
            except OSError as e:
                print(f"Error closing previous validation CSV: {e}")
        self.csv_file_handle = None
        self.csv_writer = None

        if self.global_rank != 0:
            return

        # Prefer the precise path passed from the training script.
        if hasattr(self, 'base_run_dir'):
            base_save_dir = Path(self.base_run_dir)
        else:
            # Fallback to old logic for backward compatibility
            project_name = "default_project"
            run_name = "default_run"
            if hasattr(self, 'logger') and hasattr(self.logger, 'experiment'):
                exp_project = getattr(self.logger.experiment, 'project', project_name)
                project_name = exp_project if exp_project is not None else project_name

                exp_run_name = getattr(self.logger.experiment, 'name', None)
                if exp_run_name:
                    run_name = exp_run_name
                elif hasattr(self.logger.experiment, 'id') and self.logger.experiment.id is not None:
                    run_name = f"run_{self.logger.experiment.id}"

            base_save_dir = Path("experiments") / str(project_name) / str(run_name)

        validation_csv_dir = base_save_dir / "validation"
        # Use the 'csv' module for robust, incremental writing
        import csv
        try:
            os.makedirs(validation_csv_dir, exist_ok=True)
            csv_filename = validation_csv_dir / f"validation_results_epoch_{self.current_epoch}.csv"
            if os.path.exists(csv_filename):
                os.remove(csv_filename)

            # Explicit utf-8: rows contain free-form sentence text, and the
            # platform default encoding may not be able to represent it.
            self.csv_file_handle = open(csv_filename, 'w', newline='', encoding='utf-8')
            # Define a comprehensive header, including all possible fields
            fieldnames = [
                "idx", "epoch", "day", "sentence", "target_phonemes", "len_tgt",
                "predicted_phonemes_tf", "len_tf", "per_tf",
                "predicted_phonemes_sg", "len_sg", "per_sg",
                "predicted_text_bart", "wer_bart",
                "predicted_text_whisper", "wer_whisper"
            ]
            # Dynamically add aux task fields to header
            for aux_key in self.aux_heads.keys():
                fieldnames.extend([f"pred_{aux_key}", f"target_{aux_key}", f"{aux_key}_error", f"{aux_key}_mse"])

            # extrasaction='ignore': rows may carry keys that are not in the header.
            self.csv_writer = csv.DictWriter(self.csv_file_handle, fieldnames=fieldnames, extrasaction='ignore')
            self.csv_writer.writeheader()

        except Exception as e:
            # Deliberately best-effort: CSV logging must never abort validation.
            print(f"Error initializing validation CSV logging: {e}")
            if self.csv_file_handle:
                self.csv_file_handle.close()
            self.csv_file_handle = None
            self.csv_writer = None
    
    def validation_step(self, batch, batch_idx):
        """Run one validation batch through every enabled decoding path.

        Depending on configuration, this computes:
          * auxiliary-head losses, and their predictions for conditioning;
          * the teacher-forced phoneme decoder loss;
          * BART / Whisper text-decoder losses plus generated token ids;
          * the sequential-generation (SG) phoneme loss, every
            ``sg_val_every_n_epochs`` epochs.

        On rank 0, per-sample rows are appended to the CSV opened in
        ``on_validation_epoch_start``.

        Args:
            batch: dict with 'neuralData', 'neuralDataLen', 'phoneSeq',
                'phoneSeqLen', and optionally 'day' and 'sentence'.
            batch_idx: index of this batch; one-off warnings print only when it is 0.

        Returns:
            Dict of scalar losses only. Large tensors are deliberately not kept
            across steps, to avoid the OOM seen when they were accumulated.
        """
        src_ecog = batch['neuralData']
        src_ecog_lens = batch['neuralDataLen']
        days = batch.get('day', None)
        phone_seq_targets_original = batch['phoneSeq']
        phone_seq_lens_original = batch['phoneSeqLen']
        target_text_sentences_val = batch.get('sentence', None)

        output_losses = {} # This will store only scalar losses for this step

        def _labels_ignoring_padding(encoding):
            # Hugging Face seq2seq losses only ignore label positions equal to -100.
            # Tokenizer padding uses pad_token_id, which would otherwise be scored,
            # so padded positions (attention_mask == 0) are replaced with -100.
            # The models' internal shift-right maps -100 back to the pad id for
            # decoder inputs.
            labels = encoding.input_ids.to(self.device)
            attention_mask = encoding.get('attention_mask') if hasattr(encoding, 'get') else None
            if attention_mask is not None:
                labels = labels.masked_fill(attention_mask.to(self.device) == 0, -100)
            return labels

        # --- 0. Get Downsampled Lengths ---
        downsampled_src_ecog_lens_val = self._get_downsampled_lens(src_ecog_lens)

        # --- Encoder Forward (Raw ECoG Encoder Output) ---
        raw_ecog_encoder_outputs_val, intermediate_for_aux_val, memory_key_padding_mask = \
            self._encoder_forward(src_ecog, downsampled_src_ecog_lens_val, days)

        # --- Auxiliary Task Evaluation (using intermediate encoder outputs) ---
        aux_predictions_dict_val = None
        aux_targets_dict_val = None # Initialize here to prevent UnboundLocalError
        if self.aux_heads: 
            val_total_aux_loss, val_aux_losses_dict, current_aux_preds_step, current_aux_targets_step = \
                self._calculate_auxiliary_losses(intermediate_for_aux_val, batch, downsampled_src_ecog_lens_val)
            
            output_losses['val_total_aux_loss'] = val_total_aux_loss
            for name, loss_val in val_aux_losses_dict.items():
                self.log(f'val_{name}_step', loss_val, on_step=True, on_epoch=False, sync_dist=True, batch_size=src_ecog.size(0))
                output_losses[f'val_{name}'] = loss_val
            
            if current_aux_preds_step:
                # Move to CPU for processing, then discard
                aux_predictions_dict_val = {k: v.cpu() for k, v in current_aux_preds_step.items()}
            if current_aux_targets_step:
                aux_targets_dict_val = {k: v.cpu() for k, v in current_aux_targets_step.items()}

        # --- Prepare Encoder Representation for Phoneme Decoder (Potentially Augmented) ---
        final_encoder_representation_for_phoneme_decoder_val = raw_ecog_encoder_outputs_val
        if self.hparams.get('augment_encoder_with_mfcc', False) and \
           aux_predictions_dict_val is not None and 'mfcc' in aux_predictions_dict_val:
            predicted_mfccs_for_aug_val = aux_predictions_dict_val['mfcc']
            if predicted_mfccs_for_aug_val is not None:
                final_encoder_representation_for_phoneme_decoder_val = self._augment_encoder_outputs(
                    raw_ecog_encoder_outputs_val, # Augment the raw output
                    predicted_mfccs_for_aug_val.to(raw_ecog_encoder_outputs_val.device).detach()
                )

        # `_augment_encoder_outputs` may truncate the time axis to the MFCC length,
        # so keep the phoneme decoder's key padding mask aligned with its memory.
        # When nothing was truncated this slice covers the whole mask.
        # The BART/Whisper paths below use the raw encoder output and the full mask.
        phoneme_memory_padding_mask = memory_key_padding_mask
        if memory_key_padding_mask is not None:
            phoneme_memory_padding_mask = memory_key_padding_mask[
                :, :final_encoder_representation_for_phoneme_decoder_val.size(1)
            ]

        # --- Prepare Global Conditioning Vectors for Phoneme Decoder ---
        len_condition_vec_val = None
        wc_condition_vec_val = None
        if aux_predictions_dict_val is not None:
            if self.hparams.get('augment_decoder_with_phoneme_len', False) and 'phoneme_len' in aux_predictions_dict_val:
                predicted_len_scalar_val = aux_predictions_dict_val['phoneme_len']
                if predicted_len_scalar_val is not None:
                    len_condition_vec_val = self.len_embedder(predicted_len_scalar_val.to(self.device).detach())
            if self.hparams.get('augment_decoder_with_word_count', False) and 'word_count' in aux_predictions_dict_val:
                predicted_wc_scalar_val = aux_predictions_dict_val['word_count']
                if predicted_wc_scalar_val is not None:
                    wc_condition_vec_val = self.word_count_embedder(predicted_wc_scalar_val.to(self.device).detach())
        
        # --- Teacher Forcing (TF) Path for Phoneme Decoder ---
        decoder_input_tf, decoder_input_lens_tf, decoder_target_tf = \
            self._prepare_decoder_inputs_targets(phone_seq_targets_original, phone_seq_lens_original)
        logits_tf = self._decoder_forward(
            decoder_input_tf, final_encoder_representation_for_phoneme_decoder_val, 
            phoneme_memory_padding_mask, decoder_input_lens_tf,
            len_condition_vec=len_condition_vec_val, wc_condition_vec=wc_condition_vec_val
        )
        val_loss_tf = F.cross_entropy(
            logits_tf.reshape(-1, self.hparams.vocab_size), decoder_target_tf.reshape(-1),
            ignore_index=self.hparams.pad_idx 
        )
        output_losses['val_loss_tf'] = val_loss_tf
        pred_ids_tf = logits_tf.argmax(-1)
        preds_tf_list = []
        for i in range(pred_ids_tf.size(0)):
            # Target length + 1 to also keep the prediction at the EOS position.
            true_len = phone_seq_lens_original[i] + 1 
            preds_tf_list.append(pred_ids_tf[i, :true_len].cpu())
        
        targets_batch_cpu = decoder_target_tf.cpu()

        # --- BART Text Decoder Validation ---
        generated_bart_ids_cpu = None
        if self.bart_text_decoder_active and self.bart_model is not None and self.bart_tokenizer is not None:
            if target_text_sentences_val is not None and all(isinstance(s, str) for s in target_text_sentences_val):
                bart_label_encoding = self.bart_tokenizer(
                    list(target_text_sentences_val), return_tensors="pt", padding=True, truncation=True,
                    max_length=self.hparams.get('bart_max_target_len', 128)
                )
                bart_labels_val = _labels_ignoring_padding(bart_label_encoding)
                
                # This mask refers to the padding of our ECoG encoder outputs
                bart_encoder_padding_mask_val = (~memory_key_padding_mask).long() if memory_key_padding_mask is not None else None
                projected_raw_encoder_outputs_for_bart_val = self.ecog_to_bart_hidden_projection(raw_ecog_encoder_outputs_val)

                bart_outputs_val = self.bart_model(
                    encoder_outputs=(projected_raw_encoder_outputs_for_bart_val,), # Use projected RAW ECoG output
                    labels=bart_labels_val, return_dict=True,
                    attention_mask=bart_encoder_padding_mask_val # RENAMED from encoder_attention_mask
                )
                output_losses['val_loss_bart_tf'] = bart_outputs_val.loss

                max_gen_len_bart = self.hparams.get('bart_max_gen_len_val', 150)
                # For generation, attention_mask is also used for the encoder outputs
                bart_gen_encoder_padding_mask = (~memory_key_padding_mask).int() if memory_key_padding_mask is not None else None
                
                # Wrap encoder_outputs for generate method
                wrapped_encoder_outputs_val = BaseModelOutput(last_hidden_state=projected_raw_encoder_outputs_for_bart_val.contiguous())

                generated_bart_ids = self.bart_model.generate(
                    encoder_outputs=wrapped_encoder_outputs_val, # Use wrapped outputs
                    attention_mask=bart_gen_encoder_padding_mask, 
                    max_length=max_gen_len_bart,
                    num_beams=self.hparams.get('bart_val_num_beams', 4), early_stopping=True,
                    decoder_start_token_id=self.bart_model.config.decoder_start_token_id
                )
                generated_bart_ids_cpu = generated_bart_ids.cpu()
            else:
                if rank_zero_only.rank == 0 and batch_idx == 0:
                    print(f"Warning: target_text_sentences_val is None or not list of str in val_step. Type: {type(target_text_sentences_val)}. Skipping BART val.")
                output_losses['val_loss_bart_tf'] = torch.tensor(float('nan'))
                generated_bart_ids_cpu = [torch.empty(0, dtype=torch.long) for _ in range(src_ecog.size(0))]
        
        # --- Whisper Text Decoder Validation (New) ---
        generated_whisper_ids_cpu = None
        if self.whisper_text_decoder_active and self.whisper_model is not None and self.whisper_processor is not None:
            projected_raw_encoder_outputs_for_whisper_val = self.ecog_to_whisper_hidden_projection(raw_ecog_encoder_outputs_val)
            wrapped_encoder_outputs_val = BaseModelOutput(last_hidden_state=projected_raw_encoder_outputs_for_whisper_val.contiguous())

            if target_text_sentences_val is not None and all(isinstance(s, str) for s in target_text_sentences_val):
                whisper_label_encoding = self.whisper_processor(
                    text=list(target_text_sentences_val), return_tensors="pt", padding=True, truncation=True,
                    max_length=self.hparams.get('whisper_max_target_len', 128)
                )
                # Mask padding via attention_mask, not pad_token_id: Whisper's pad id
                # can coincide with its end-of-text id, so matching on the id would
                # also mask the real end token.
                whisper_labels_val = _labels_ignoring_padding(whisper_label_encoding)
                
                # The model will create `decoder_input_ids` from `labels` by shifting them right.
                whisper_outputs_val = self.whisper_model(
                    encoder_outputs=wrapped_encoder_outputs_val,
                    labels=whisper_labels_val, return_dict=True
                )
                output_losses['val_loss_whisper_tf'] = whisper_outputs_val.loss

            max_gen_len_whisper = self.hparams.get('whisper_max_gen_len_val', 150)
            
            # Create attention mask for generation (1 for real tokens, 0 for padding)
            whisper_encoder_padding_mask_val = (~memory_key_padding_mask).long() if memory_key_padding_mask is not None else None

            # Get forced decoder IDs for English transcription task
            forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(language="en", task="transcribe")

            generated_whisper_ids = self.whisper_model.generate(
                encoder_outputs=wrapped_encoder_outputs_val,
                attention_mask=whisper_encoder_padding_mask_val, # Add attention mask
                max_length=max_gen_len_whisper,
                num_beams=self.hparams.get('whisper_val_num_beams', 4),
                early_stopping=True,
                forced_decoder_ids=forced_decoder_ids # Use explicit forced decoder IDs
            )
            generated_whisper_ids_cpu = generated_whisper_ids.cpu()

        # --- Sequential Generation (SG) Path for Phoneme Decoder ---
        gen_ids_sg_cpu = None
        run_sg_validation = False
        sg_val_freq = self.hparams.get('sg_val_every_n_epochs', 1)
        if sg_val_freq > 0 and (self.current_epoch + 1) % sg_val_freq == 0:
            run_sg_validation = True

        if run_sg_validation:
            max_gen_len_val = self.hparams.get('max_gen_len_val', decoder_target_tf.size(1) + 10) 
            guidance_lengths_for_sg_val = None
            if self.hparams.get('use_predicted_len_for_sg', False):
                if aux_predictions_dict_val and 'phoneme_len' in aux_predictions_dict_val:
                    raw_len_preds = aux_predictions_dict_val['phoneme_len'].squeeze(-1) 
                    raw_len_preds = raw_len_preds + self.hparams.get('sg_gen_buffer', 0.0)
                    guidance_lengths_for_sg_val = torch.round(raw_len_preds).long().to(self.device) 
                elif rank_zero_only.rank == 0 and batch_idx == 0: 
                    print("Warning: 'use_predicted_len_for_sg' True, but 'phoneme_len' not in aux_predictions_dict_val for SG val.")

            gen_logits_sg, gen_ids_sg = self._perform_sequential_generation(
                final_encoder_representation_for_phoneme_decoder_val, # Potentially augmented
                phoneme_memory_padding_mask, max_len=max_gen_len_val,
                guidance_lengths=guidance_lengths_for_sg_val, use_gumbel_softmax=False, 
                len_condition_vec=len_condition_vec_val, wc_condition_vec=wc_condition_vec_val
            )
            gen_ids_sg_cpu = gen_ids_sg.cpu() 
            # Generated and target lengths can differ; score only the overlap.
            len_for_loss_sg = min(gen_logits_sg.size(1), decoder_target_tf.size(1))
            if len_for_loss_sg > 0:
                val_loss_sg = F.cross_entropy(
                    gen_logits_sg[:, :len_for_loss_sg, :].reshape(-1, self.hparams.vocab_size),
                    decoder_target_tf[:, :len_for_loss_sg].reshape(-1), ignore_index=self.hparams.pad_idx
                )
                output_losses['val_loss_sg'] = val_loss_sg
            else:
                output_losses['val_loss_sg'] = torch.tensor(float('nan'))
        
        # --- Process this batch's results and write to CSV ---
        if self.global_rank == 0 and self.csv_writer is not None:
            batch_records = self._process_validation_batch_for_logging(
                batch, preds_tf_list, targets_batch_cpu, gen_ids_sg_cpu, 
                generated_bart_ids_cpu, generated_whisper_ids_cpu,
                aux_predictions_dict_val, aux_targets_dict_val
            )
            try:
                self.csv_writer.writerows(batch_records)
            except Exception as e:
                # Best-effort logging: a CSV failure must not abort validation.
                print(f"Error writing to CSV in validation_step: {e}")

        # Large tensors are intentionally not appended to self.validation_step_outputs:
        # accumulating them across the epoch caused an OOM.
        
        # Return only the scalar losses which PL will aggregate
        return output_losses

    def _process_validation_batch_for_logging(self, batch, preds_tf_batch, targets_batch, preds_sg_batch, 
                                               preds_bart_text_ids_batch, preds_whisper_text_ids_batch, 
                                               aux_predictions_batch, aux_targets_batch):
        """
        Process one batch of validation results into a list of lightweight record dictionaries.
        This is a helper function to keep validation_step cleaner.
        """
        batch_records = []
        current_batch_size = batch['neuralData'].size(0)
        days_batch = batch.get('day').cpu().tolist() if 'day' in batch else ["N/A"] * current_batch_size
        target_text_batch = batch.get('sentence', ["N/A"] * current_batch_size)

        for i in range(current_batch_size):
            current_record = {"idx": self.global_sample_idx, "epoch": self.current_epoch, "day": days_batch[i]}
            self.global_sample_idx += 1
            
            # Target processing
            target_ids_sample = targets_batch[i] if targets_batch is not None and i < len(targets_batch) else None
            if target_ids_sample is not None:
                current_record["target_phonemes"] = " ".join(self._ids_to_phoneme_list(target_ids_sample, strip_sos=False, strip_pad=True, keep_eos_str=True))
                current_record["len_tgt"] = self._get_true_length(target_ids_sample)
            
            current_record["sentence"] = target_text_batch[i] if i < len(target_text_batch) else "N/A"
            current_record["target_text"] = current_record["sentence"]

            # TF prediction processing
            if preds_tf_batch is not None and i < len(preds_tf_batch):
                pred_tf_ids_sample = preds_tf_batch[i]
                current_record["predicted_phonemes_tf"] = " ".join(self._ids_to_phoneme_list(pred_tf_ids_sample, strip_sos=False, strip_pad=True, keep_eos_str=True, stop_at_eos=False))
                current_record["len_tf"] = self._get_true_length(pred_tf_ids_sample)
                if target_ids_sample is not None:
                    per_tf = self._calculate_per(target_ids_sample, pred_tf_ids_sample)
                    self.all_pers_tf.append(per_tf)
                    current_record["per_tf"] = f"{per_tf:.4f}"

            # SG prediction processing
            if preds_sg_batch is not None and i < preds_sg_batch.size(0):
                pred_sg_ids_sample = preds_sg_batch[i]
                current_record["predicted_phonemes_sg"] = " ".join(self._ids_to_phoneme_list(pred_sg_ids_sample, strip_sos=True, strip_pad=True, keep_eos_str=True))
                current_record["len_sg"] = self._get_true_length(pred_sg_ids_sample)
                if target_ids_sample is not None:
                    per_sg = self._calculate_per(target_ids_sample, pred_sg_ids_sample)
                    self.all_pers_sg.append(per_sg)
                    current_record["per_sg"] = f"{per_sg:.4f}"
            
            # BART prediction processing
            if self.bart_text_decoder_active and preds_bart_text_ids_batch is not None and i < len(preds_bart_text_ids_batch):
                single_bart_pred_ids = preds_bart_text_ids_batch[i]
                if isinstance(single_bart_pred_ids, torch.Tensor) and single_bart_pred_ids.numel() > 0:
                    pred_text_list = self.bart_tokenizer.batch_decode(single_bart_pred_ids.unsqueeze(0), skip_special_tokens=True, clean_up_tokenization_spaces=True)
                    predicted_text_bart_sample = pred_text_list[0] if pred_text_list else "N/A"
                    current_record["predicted_text_bart"] = predicted_text_bart_sample
                    
                    # WER Calculation
                    target_text_sample = current_record.get("target_text", "N/A")
                    if isinstance(target_text_sample, str) and target_text_sample != "N/A":
                        processed_target = self._preprocess_text(target_text_sample)
                        processed_pred = self._preprocess_text(predicted_text_bart_sample)
                        if processed_target:
                            wer = jiwer.wer(processed_target, processed_pred)
                            self.all_wers_bart.append(wer)
                            current_record["wer_bart"] = f"{wer:.4f}"

            # Whisper prediction processing
            if self.whisper_text_decoder_active and preds_whisper_text_ids_batch is not None and i < len(preds_whisper_text_ids_batch):
                single_whisper_pred_ids = preds_whisper_text_ids_batch[i]
                if isinstance(single_whisper_pred_ids, torch.Tensor) and single_whisper_pred_ids.numel() > 0:
                    pred_text_list_whisper = self.whisper_tokenizer.batch_decode(single_whisper_pred_ids.unsqueeze(0), skip_special_tokens=True, clean_up_tokenization_spaces=True)
                    predicted_text_whisper_sample = pred_text_list_whisper[0] if pred_text_list_whisper else "N/A"
                    current_record["predicted_text_whisper"] = predicted_text_whisper_sample

                    target_text_sample = current_record.get("target_text", "N/A")
                    if isinstance(target_text_sample, str) and target_text_sample != "N/A":
                        processed_target = self._preprocess_text(target_text_sample)
                        processed_pred = self._preprocess_text(predicted_text_whisper_sample)
                        if processed_target:
                            wer_whisper = jiwer.wer(processed_target, processed_pred)
                            self.all_wers_whisper.append(wer_whisper)
                            current_record["wer_whisper"] = f"{wer_whisper:.4f}"

            # Aux processing
            if aux_predictions_batch and aux_targets_batch:
                for key in self.aux_heads.keys():
                    if key in aux_predictions_batch and key in aux_targets_batch:
                         pred_val, target_val = aux_predictions_batch[key][i], aux_targets_batch[key][i]
                         current_record[f"pred_{key}"] = f"{pred_val.item():.2f}" if pred_val.numel() == 1 else "N/A"
                         current_record[f"target_{key}"] = f"{target_val.item():.2f}" if target_val.numel() == 1 else "N/A"
                         if key in ['phoneme_len', 'word_count'] and pred_val.numel() == 1 and target_val.numel() == 1:
                            error = target_val.item() - pred_val.item()
                            self.aux_error_sums[key] += error
                            self.aux_error_counts[key] += 1
                            current_record[f"{key}_error"] = error
                         elif key in ['mfcc', 'envelope']:
                             # pred_val shape is [T_pred, D], target_val shape is [T_target, D]
                             max_len = min(pred_val.shape[0], target_val.shape[0])
                             if max_len > 0:
                                 mse = F.mse_loss(pred_val[:max_len], target_val[:max_len], reduction='mean').item()
                                 current_record[f"{key}_mse"] = mse
                             else:
                                 current_record[f"{key}_mse"] = "N/A"

            batch_records.append(current_record)
        return batch_records

    def on_validation_epoch_end(self):
        """
        Finalize a validation epoch.

        The method closes the per-epoch CSV file, logs the aggregated PER/WER and
        auxiliary-error metrics, and drops the per-epoch accumulators.

        The accumulators (``all_pers_*``, ``all_wers_*``, ``aux_error_*``) are only
        populated on global rank 0, because ``validation_step`` only calls
        ``_process_validation_batch_for_logging`` when ``self.global_rank == 0``.
        Logging them with ``sync_dist=True`` would make rank 0 enter a cross-process
        reduction that the other ranks never join, which deadlocks multi-process
        (DDP) runs. They are therefore logged with ``rank_zero_only=True``.

        On a single device this is equivalent to the previous behaviour.
        NOTE(review): under DDP these rank-0-only metrics cannot be used as callback
        or scheduler monitors on the other ranks.
        """
        # --- Close CSV File Handle (only rank 0 writes the CSV) ---
        if self.global_rank == 0 and getattr(self, 'csv_file_handle', None) is not None:
            try:
                self.csv_file_handle.close()
                print(f"Finished writing validation results for epoch {self.current_epoch}.")
            except Exception as e:
                print(f"Error closing validation CSV file: {e}")

        def _mean_of_floats(values):
            # Entries that are not Python floats (e.g. placeholders) are ignored.
            floats = [v for v in values if isinstance(v, float)]
            return torch.tensor(floats).mean().item() if floats else None

        # --- Log Aggregated Scalar Metrics ---
        # (metric name, accumulated per-sample values, whether the producing path is enabled)
        epoch_metrics = (
            ('val_per_tf_epoch', self.all_pers_tf, True),
            ('val_per_sg_epoch', self.all_pers_sg, True),
            ('val_wer_bart_epoch', self.all_wers_bart, self.bart_text_decoder_active),
            ('val_wer_whisper_epoch', self.all_wers_whisper, self.whisper_text_decoder_active),
        )
        for metric_name, values, enabled in epoch_metrics:
            if not (enabled and values):
                continue
            mean_value = _mean_of_floats(values)
            if mean_value is not None:
                self.log(metric_name, mean_value, prog_bar=True, rank_zero_only=True)

        for aux_task_key in self.aux_error_sums:
            count = self.aux_error_counts[aux_task_key]
            if count > 0:
                avg_error = self.aux_error_sums[aux_task_key] / count
                self.log(f'val_{aux_task_key}_avg_error', avg_error, rank_zero_only=True)

        # Per-step losses returned from validation_step are aggregated by Lightning itself.

        # Drop per-epoch state so the next validation epoch starts clean (presumably it is
        # re-created in on_validation_epoch_start -- confirm). The CSV attributes may never
        # have been created on non-zero ranks, so missing attributes are tolerated.
        for attr_name in ('all_pers_tf', 'all_pers_sg', 'all_wers_bart', 'all_wers_whisper',
                          'aux_error_sums', 'aux_error_counts', 'csv_file_handle', 'csv_writer'):
            if hasattr(self, attr_name):
                delattr(self, attr_name)

    def configure_optimizers(self):
        """
        Build the AdamW optimizer, plus an optional ReduceLROnPlateau scheduler, for the current training stage.

        Which modules contribute parameters depends on ``hparams.training_stage``:
          * 'encoder_only': input projection, day transform, encoder, and every aux head
            whose ``train_<name>_aux`` flag is set.
          * 'joint_teacher_forcing' / 'joint_sequential_generation': encoder parts unless
            ``freeze_encoder_in_joint``; phoneme embedding/decoder/output projection unless
            ``freeze_decoder_in_joint``; active aux heads if ``train_aux_in_joint_tf`` or
            ``train_aux_in_joint_sg`` is set.
          * 'secondary_only': encoder parts, plus active aux heads if
            ``train_aux_in_secondary_only`` is set.
        The BART / Whisper models and their ECoG projections are added whenever the
        corresponding text decoder is active.

        Only tensors with ``requires_grad=True`` are kept. Each tensor is registered
        once, in first-seen order.

        Returns:
            The optimizer, or a Lightning dict with ``optimizer`` and ``lr_scheduler``
            entries when ``hparams.use_scheduler`` is truthy.

        Raises:
            ValueError: if the training stage is unknown, or no trainable parameter is selected.
        """
        # --- Explicitly cast lr and weight_decay to float (configs may supply strings) ---
        lr = float(getattr(self.hparams, 'learning_rate', 1e-4))
        weight_decay = float(getattr(self.hparams, 'weight_decay', 0.0001))
        stage = self.hparams.training_stage

        candidate_params = []

        def add_module(module):
            # None means "not configured"; nothing to optimize.
            if module is not None:
                candidate_params.extend(module.parameters())

        def add_active_aux_heads():
            for head_name, head_module in self.aux_heads.items():
                if self.hparams.get(f'train_{head_name}_aux', False):  # Only aux tasks that are switched on
                    candidate_params.extend(head_module.parameters())

        def add_encoder_parts():
            add_module(self.input_projection)
            add_module(self.day_transform)
            add_module(self.encoder)

        if stage == 'encoder_only':
            print("Configuring optimizer for: ENCODER_ONLY stage")
            add_encoder_parts()
            add_active_aux_heads()

        elif stage in ['joint_teacher_forcing', 'joint_sequential_generation']:
            print(f"Configuring optimizer for: {stage} stage")
            if not self.hparams.get('freeze_encoder_in_joint', False):
                add_encoder_parts()
            else:
                print("Encoder parts are FROZEN for joint training.")

            if not self.hparams.get('freeze_decoder_in_joint', False):
                add_module(self.phoneme_embedding)
                add_module(self.decoder)
                add_module(self.output_projection)
            else:
                print("Decoder parts are FROZEN for joint training.")

            if self.hparams.get('train_aux_in_joint_tf', False) or \
               self.hparams.get('train_aux_in_joint_sg', False):
                add_active_aux_heads()

        elif stage == 'secondary_only':
            print("Configuring optimizer for: SECONDARY_ONLY stage")
            # ECoG encoder + active secondary head (BART/Whisper, added below).
            # The phoneme decoder is frozen implicitly because its parameters are not added.
            add_encoder_parts()
            if self.hparams.get('train_aux_in_secondary_only', False):
                add_active_aux_heads()

        else:
            raise ValueError(f"Unknown training_stage for optimizer: {stage}")

        # --- BART Text Decoder Parameters ---
        if self.bart_text_decoder_active and self.bart_model is not None:
            bart_projection = getattr(self, 'ecog_to_bart_hidden_projection', None)
            if isinstance(bart_projection, nn.Linear):
                add_module(bart_projection)
            if rank_zero_only.rank == 0: print("Adding BART model trainable parameters to optimizer.")
            # BART's requires_grad flags are set elsewhere (_apply_freezing_from_hparams).
            add_module(self.bart_model)

        # --- Whisper Text Decoder Parameters ---
        if self.whisper_text_decoder_active and self.whisper_model is not None:
            whisper_projection = getattr(self, 'ecog_to_whisper_hidden_projection', None)
            if isinstance(whisper_projection, nn.Module):  # Tolerate an attribute explicitly set to None
                add_module(whisper_projection)
            if rank_zero_only.rank == 0: print("Adding Whisper model trainable parameters to optimizer.")
            add_module(self.whisper_model)

        # Filter frozen tensors and drop duplicates BEFORE the emptiness check. This way a
        # fully-frozen configuration raises the descriptive error below instead of AdamW's
        # generic "empty parameter list", and no tensor is registered twice.
        seen_param_ids = set()
        parameters_to_optimize = []
        for param in candidate_params:
            if param.requires_grad and id(param) not in seen_param_ids:
                seen_param_ids.add(id(param))
                parameters_to_optimize.append(param)

        if not parameters_to_optimize:
            raise ValueError(f"No parameters selected for optimization in stage '{stage}'. Check config and freezing flags.")

        optimizer = torch.optim.AdamW(parameters_to_optimize, lr=lr, weight_decay=weight_decay)

        # Optional: Learning rate scheduler
        if self.hparams.get('use_scheduler', False):
            monitor = self.hparams.get('scheduler_monitor', 'val_total_loss_epoch')
            print(f"SCHEDULER MONITOR: {monitor}")

            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode=self.hparams.get('scheduler_mode', 'min'),
                factor=self.hparams.get('scheduler_factor', 0.1),
                patience=self.hparams.get('scheduler_patience', 10),
                min_lr=self.hparams.get('min_lr', 1e-6),
                threshold=self.hparams.get('scheduler_threshold', 1e-3)
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": monitor,  # Must match a metric logged on every rank
                },
            }
        return optimizer


    def _get_auxiliary_predictions(self, encoder_outputs, src_ecog_lens=None, memory_key_padding_mask=None):
        """
        Computes raw predictions from auxiliary heads without calculating losses.
        """
        aux_predictions_dict = {}
        if not hasattr(self, 'aux_heads') or not self.aux_heads:
            return aux_predictions_dict

        # Per-timestep auxiliary tasks
        if 'mfcc' in self.aux_heads:
            mfcc_preds_raw = self.aux_heads['mfcc'](encoder_outputs) # [B, T_enc, num_mfcc]
            aux_predictions_dict['mfcc'] = mfcc_preds_raw

        if 'envelope' in self.aux_heads:
            env_preds_raw = self.aux_heads['envelope'](encoder_outputs) # [B, T_enc, 1]
            aux_predictions_dict['envelope'] = env_preds_raw

        # Sequence-level auxiliary tasks
        pooled_encoder_output = None
        needs_pooling = ('phoneme_len' in self.aux_heads) or ('word_count' in self.aux_heads)

        if needs_pooling:
            # Determine padding mask for pooling
            current_padding_mask_for_pool = memory_key_padding_mask
            if current_padding_mask_for_pool is None and src_ecog_lens is not None:
                current_padding_mask_for_pool = self._generate_padding_mask(src_ecog_lens, encoder_outputs.size(1))
            
            if current_padding_mask_for_pool is not None:
                pool_mask = (~current_padding_mask_for_pool).unsqueeze(-1).float()
                sum_pool_mask = pool_mask.sum(dim=1).clamp(min=1e-9) # Avoid division by zero
                pooled_encoder_output = (encoder_outputs * pool_mask).sum(dim=1) / sum_pool_mask
            elif encoder_outputs.size(1) > 0: # Fallback: average pool if no mask/lengths but has content
                pooled_encoder_output = encoder_outputs.mean(dim=1)
            else: # Cannot pool if sequence length is 0
                pass # pooled_encoder_output remains None

        if pooled_encoder_output is not None:
            if 'phoneme_len' in self.aux_heads:
                # Raw output from the linear layer for phoneme_len
                phoneme_len_preds_unscaled = self.aux_heads['phoneme_len'](pooled_encoder_output) # [B, 1]
                # Apply scaling factor as done in training for consistency if this value is used elsewhere
                scaling_factor = self.hparams.get('phoneme_len_scaling_factor', 30.0)
                aux_predictions_dict['phoneme_len'] = phoneme_len_preds_unscaled * scaling_factor
                aux_predictions_dict['phoneme_len_unscaled'] = phoneme_len_preds_unscaled # Also store unscaled

            if 'word_count' in self.aux_heads:
                # Raw output from the linear layer for word_count
                word_count_preds_unscaled = self.aux_heads['word_count'](pooled_encoder_output) # [B, 1]
                scaling_factor = self.hparams.get('word_count_scaling_factor', 10.0)
                aux_predictions_dict['word_count'] = word_count_preds_unscaled * scaling_factor
                aux_predictions_dict['word_count_unscaled'] = word_count_preds_unscaled # Also store unscaled
        
        return aux_predictions_dict

    def _perform_beam_search_generation(self, memory, memory_key_padding_mask, max_len,
                                        beam_width, length_penalty_alpha,
                                        len_condition_vec=None, wc_condition_vec=None):
        """
        Performs beam search for a single sample (B=1).
        Returns padded logits [1, max_len-1, vocab_size] and padded ids [1, max_len].

        Args:
            memory: Encoder output for one sample; unpacked as (B, T_mem, D_mem) and B must be 1.
            memory_key_padding_mask: Passed unchanged to ``self.decoder`` on every step.
            max_len: Output length in tokens, counting the leading SOS token; at most
                ``max_len - 1`` tokens are generated.
            beam_width: Number of unfinished hypotheses kept after each step. It is also the
                ``k`` used in ``torch.topk`` over each beam's next-token log-probabilities.
                NOTE(review): ``torch.topk`` raises if this exceeds ``hparams.vocab_size``.
            length_penalty_alpha: Exponent of the length normalisation applied to finished
                hypotheses: ``score / (num_generated_tokens ** alpha)``.
            len_condition_vec, wc_condition_vec: Optional conditioning vectors. They are forwarded
                to ``self._augment_decoder_input_embeddings`` when either
                ``augment_decoder_with_phoneme_len`` or ``augment_decoder_with_word_count`` is set.

        Returns:
            A tuple ``(final_logits, final_ids)``:
              * ``final_logits``: the per-step logits of the selected hypothesis, zero-padded
                (or truncated) to ``max_len - 1`` steps.
              * ``final_ids``: its token ids, starting with SOS, padded with ``pad_idx``
                (or truncated) to ``max_len``.

        Raises:
            ValueError: if ``memory`` has a batch dimension other than 1.

        Notes:
            - The decoder is re-run over each beam's full prefix every step; beams are not batched.
            - The search ends early only when no unfinished beams remain. Otherwise it runs
              ``max_len - 1`` steps, even after enough hypotheses have finished.
            - If any hypothesis emitted EOS, only finished hypotheses are eligible for the final
              answer; unfinished beams are used only as a fallback.
        """
        B, T_mem, D_mem = memory.shape
        if B != 1:
            raise ValueError("_perform_beam_search_generation expects batch_size=1.")

        device = memory.device
        vocab_size = self.hparams.vocab_size
        sos_idx = self.hparams.sos_idx
        eos_idx = self.hparams.eos_idx
        pad_idx = self.hparams.pad_idx

        # Each beam: (sequence_tensor [1, current_len], cumulative_log_prob_score, list_of_step_logits)
        # The score is a raw sum of log-probs; the length penalty is only applied when a beam is moved to completed_hypotheses.
        active_beams = [(torch.full((1, 1), sos_idx, dtype=torch.long, device=device), 0.0, [])]
        completed_hypotheses = []

        for step in range(max_len - 1): # Max_len includes SOS, so max_len-1 steps of generation
            if not active_beams:
                break

            all_candidates_for_step = []
            for current_seq, current_score, current_logits_hist in active_beams:
                # Prepare decoder input: embed the whole prefix (no incremental caching)
                tgt_emb = self.phoneme_embedding(current_seq)
                if self.hparams.get('augment_decoder_with_phoneme_len', False) or \
                   self.hparams.get('augment_decoder_with_word_count', False):
                    if len_condition_vec is not None or wc_condition_vec is not None: # These are already [1, D_global_emb]
                        tgt_emb = self._augment_decoder_input_embeddings(tgt_emb, len_condition_vec, wc_condition_vec)

                # Scale embeddings by sqrt(d_model) before pos_decoder (presumably positional encoding)
                tgt_emb = self.pos_decoder(tgt_emb * math.sqrt(self.hparams.d_model))
                # Causal mask so each position only attends to earlier tokens
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(current_seq.size(1), device=device)

                decoder_output = self.decoder(
                    tgt_emb, memory, tgt_mask=tgt_mask,
                    memory_key_padding_mask=memory_key_padding_mask
                ) # [1, current_len, d_model]

                # Only the last position's output predicts the next token
                next_token_logits = self.output_projection(decoder_output[:, -1, :]) # [1, vocab_size]
                next_token_log_probs = F.log_softmax(next_token_logits, dim=-1) # [1, vocab_size]

                # Get top k candidates for this beam
                top_k_log_probs, top_k_indices = torch.topk(next_token_log_probs.squeeze(0), beam_width)

                for k_idx in range(top_k_indices.size(0)):
                    token_id = top_k_indices[k_idx].unsqueeze(0).unsqueeze(0) # [1,1]
                    log_prob = top_k_log_probs[k_idx].item()

                    new_seq = torch.cat([current_seq, token_id], dim=1)
                    new_score = current_score + log_prob
                    new_logits_hist = current_logits_hist + [next_token_logits] # Store [1, vocab_size]

                    all_candidates_for_step.append((new_seq, new_score, new_logits_hist))

            # Prune candidates from all beams for this step
            # Sort all candidates by raw (un-penalized) score, best first
            ordered_candidates = sorted(all_candidates_for_step, key=lambda x: x[1], reverse=True)

            active_beams = [] # Reset active beams for next step
            for cand_seq, cand_score, cand_logits_hist in ordered_candidates:
                if cand_seq[0, -1].item() == eos_idx:
                    # Apply length penalty: score / (sequence_length ** alpha)
                    # sequence_length for penalty usually doesn't include SOS.
                    seq_len_for_penalty = max(1, cand_seq.size(1) - 1) # Number of actual generated tokens
                    penalized_score = cand_score / (seq_len_for_penalty ** length_penalty_alpha)
                    completed_hypotheses.append((cand_seq, penalized_score, cand_logits_hist))
                else:
                    # This is an active beam, add to list for next step if we need more
                    if len(active_beams) < beam_width:
                        active_beams.append((cand_seq, cand_score, cand_logits_hist))

                # Optimization: if we have enough completed and active beams, we might prune further
                # For now, just collect up to beam_width active ones.
                # NOTE: because this break sits inside the candidate loop, EOS candidates ranked
                # below the break point in this step are never added to completed_hypotheses.
                if len(active_beams) >= beam_width and len(completed_hypotheses) >= beam_width: # Heuristic
                    break 

            # Ensure we don't exceed beam_width for active beams
            active_beams = active_beams[:beam_width]


        # If no completed hypotheses (e.g., max_len reached before EOS),
        # add current best active beams to completed_hypotheses with penalty.
        if not completed_hypotheses:
            for cand_seq, cand_score, cand_logits_hist in active_beams:
                seq_len_for_penalty = max(1, cand_seq.size(1) - 1)
                penalized_score = cand_score / (seq_len_for_penalty ** length_penalty_alpha)
                completed_hypotheses.append((cand_seq, penalized_score, cand_logits_hist))

        if not completed_hypotheses: # Still nothing (e.g., max_len=1, only SOS)
            # Return SOS and empty, padded logits
            final_ids = torch.full((1, max_len), pad_idx, dtype=torch.long, device=device)
            final_ids[0,0] = sos_idx
            final_logits = torch.zeros((1, max_len - 1, vocab_size), device=device, dtype=memory.dtype)
            return final_logits, final_ids

        # Select the best one from completed_hypotheses (highest penalized_score)
        best_hyp = sorted(completed_hypotheses, key=lambda x: x[1], reverse=True)[0]

        best_id_sequence = best_hyp[0]    # [1, GenLen]
        best_logits_history = best_hyp[2] # List of [1, VocabSize] tensors

        # Pad generated sequence to max_len (pad_idx after EOS / end of generation)
        current_len = best_id_sequence.size(1)
        if current_len < max_len:
            padding = torch.full((1, max_len - current_len), pad_idx, dtype=torch.long, device=device)
            final_ids = torch.cat([best_id_sequence, padding], dim=1)
        else:
            final_ids = best_id_sequence[:, :max_len]

        # Prepare and pad logits
        if best_logits_history:
            # Stack logits: list of [1, VocabSize] -> [1, ActualGenLen-1, VocabSize]
            stacked_logits = torch.cat(best_logits_history, dim=0).unsqueeze(0)
        else: # Only SOS was generated
            stacked_logits = torch.empty((1, 0, vocab_size), device=device, dtype=memory.dtype)

        # Zero-pad (or truncate) to exactly max_len-1 steps so callers get a fixed shape
        num_actual_logit_steps = stacked_logits.size(1)
        if num_actual_logit_steps < max_len - 1:
            logit_padding_size = max_len - 1 - num_actual_logit_steps
            logit_padding = torch.zeros((1, logit_padding_size, vocab_size), device=device, dtype=stacked_logits.dtype)
            final_logits = torch.cat([stacked_logits, logit_padding], dim=1)
        elif num_actual_logit_steps > max_len - 1:
            final_logits = stacked_logits[:, :max_len - 1, :]
        else:
            final_logits = stacked_logits

        return final_logits, final_ids

    def predict(self, src_ecog, src_ecog_lens, days=None, mode="greedy", max_len=70,
                beam_width=5, length_penalty_alpha=0.6,
                bart_max_gen_len_predict=None, 
                bart_num_beams_predict=None,
                whisper_max_gen_len_predict=None,
                whisper_num_beams_predict=None,
                downsample_factor_override=None,
                visualize_bart_attention: bool = False):
        """
        Performs inference to generate phoneme sequences and optionally text.
        """
        self.eval()
        with torch.no_grad():
            B = src_ecog.size(0)
            device = src_ecog.device

            # --- Handle downsampling for prediction ---
            if downsample_factor_override is not None:
                # This is a temporary override for prediction time.
                # It does not change self.downsample_factor permanently.
                # The input_projection layer must be compatible with this override.
                # This is only safe if input_projection is `Unfold`, as it's dynamic.
                if self.hparams.get('feature_extractor_type') != 'unfold':
                    rank_zero_info("Warning: `downsample_factor_override` is used but feature extractor is not 'unfold'. This may lead to unexpected behavior if the extractor has a fixed factor.")
                
                # We need to manually apply unfold if overriding
                unfolder = Unfold(downsample_factor_override)
                src_ecog = unfolder(src_ecog)
                downsampled_src_ecog_lens_pred = torch.ceil(src_ecog_lens.float() / downsample_factor_override).long()
            else:
                # Use the model's default downsampling factor
                downsampled_src_ecog_lens_pred = self._get_downsampled_lens(src_ecog_lens)


            # 1. Encoder Forward (Raw ECoG Encoder Output)
            raw_ecog_encoder_outputs_pred, intermediate_for_aux_pred, memory_key_padding_mask = \
                self._encoder_forward(src_ecog, downsampled_src_ecog_lens_pred, days)

            # 2. Get Auxiliary Predictions (using intermediate encoder outputs)
            aux_predictions_dict_inf = self._get_auxiliary_predictions(
                intermediate_for_aux_pred, downsampled_src_ecog_lens_pred, memory_key_padding_mask
            )

            # 3. Prepare Encoder Representation for Phoneme Decoder (Potentially Augmented)
            final_encoder_representation_for_phoneme_decoder_pred = raw_ecog_encoder_outputs_pred
            if self.hparams.get('augment_encoder_with_mfcc', False) and \
               aux_predictions_dict_inf is not None and 'mfcc' in aux_predictions_dict_inf:
                predicted_mfccs_for_aug_pred = aux_predictions_dict_inf['mfcc']
                if predicted_mfccs_for_aug_pred is not None:
                    final_encoder_representation_for_phoneme_decoder_pred = self._augment_encoder_outputs(
                        raw_ecog_encoder_outputs_pred, # Augment the raw output
                        predicted_mfccs_for_aug_pred.to(raw_ecog_encoder_outputs_pred.device).detach()
                    )
            
            # 4. Prepare Global Conditioning Vectors for Phoneme Decoder
            len_condition_vec_inf = None
            wc_condition_vec_inf = None
            if aux_predictions_dict_inf:
                if self.hparams.get('augment_decoder_with_phoneme_len', False) and \
                   'phoneme_len' in aux_predictions_dict_inf and hasattr(self, 'len_embedder'):
                    predicted_len_scalar = aux_predictions_dict_inf['phoneme_len']
                    if predicted_len_scalar is not None:
                        len_condition_vec_inf = self.len_embedder(predicted_len_scalar.to(self.device).detach())

                if self.hparams.get('augment_decoder_with_word_count', False) and \
                   'word_count' in aux_predictions_dict_inf and hasattr(self, 'word_count_embedder'):
                    predicted_wc_scalar = aux_predictions_dict_inf['word_count']
                    if predicted_wc_scalar is not None:
                        wc_condition_vec_inf = self.word_count_embedder(predicted_wc_scalar.to(self.device).detach())

            results = {
                "predicted_phoneme_text": None, "predicted_phoneme_ids": None, "predicted_logits": None,
                "predicted_text_bart": None, "predicted_text_bart_ids": None,
                "predicted_text_whisper": None, "predicted_text_whisper_ids": None,
                "auxiliary_predictions": aux_predictions_dict_inf,
                "bart_cross_attention": None # Add key for storing attention
            }

            # 5.A Phoneme Generation
            generated_phoneme_ids = None
            generated_phoneme_logits = None

            if mode == "greedy":
                generated_phoneme_logits, generated_phoneme_ids = self._perform_sequential_generation(
                    final_encoder_representation_for_phoneme_decoder_pred, memory_key_padding_mask, max_len=max_len,
                    use_gumbel_softmax=False, 
                    len_condition_vec=len_condition_vec_inf, wc_condition_vec=wc_condition_vec_inf
                )
                # Pad logits if necessary
                num_actual_logits_steps = generated_phoneme_logits.size(1)
                if num_actual_logits_steps < max_len - 1:
                    logit_padding_size = max_len - 1 - num_actual_logits_steps
                    logit_padding = torch.zeros((B, logit_padding_size, self.hparams.vocab_size), device=device, dtype=generated_phoneme_logits.dtype)
                    generated_phoneme_logits = torch.cat([generated_phoneme_logits, logit_padding], dim=1)
                elif num_actual_logits_steps > max_len -1:
                    generated_phoneme_logits = generated_phoneme_logits[:, :max_len-1, :]

            elif mode == "beam_search":
                generated_phoneme_ids_list_pred_mode = [] # Renamed to avoid clash with outer scope
                generated_phoneme_logits_list_pred_mode = []
                for i in range(B): 
                    single_memory = final_encoder_representation_for_phoneme_decoder_pred[i:i+1] # Use potentially augmented
                    single_mem_pad_mask = memory_key_padding_mask[i:i+1] if memory_key_padding_mask is not None else None
                    single_len_cond = len_condition_vec_inf[i:i+1] if len_condition_vec_inf is not None else None
                    single_wc_cond = wc_condition_vec_inf[i:i+1] if wc_condition_vec_inf is not None else None

                    logits_i, ids_i = self._perform_beam_search_generation(
                        single_memory, single_mem_pad_mask, max_len, beam_width, length_penalty_alpha,
                        single_len_cond, single_wc_cond
                    )
                    generated_phoneme_ids_list_pred_mode.append(ids_i)
                    generated_phoneme_logits_list_pred_mode.append(logits_i)
                generated_phoneme_ids = torch.cat(generated_phoneme_ids_list_pred_mode, dim=0)
                generated_phoneme_logits = torch.cat(generated_phoneme_logits_list_pred_mode, dim=0)
            else:
                raise ValueError(f"Unsupported phoneme inference mode: {mode}.")

            results["predicted_phoneme_ids"] = generated_phoneme_ids
            results["predicted_logits"] = generated_phoneme_logits
            
            all_predicted_phoneme_text = []
            if generated_phoneme_ids is not None:
                for i in range(generated_phoneme_ids.size(0)):
                    phoneme_list = self._ids_to_phoneme_list(
                        generated_phoneme_ids[i], strip_sos=True, strip_pad=True, 
                        keep_eos_str=False, stop_at_eos=True
                    )
                    all_predicted_phoneme_text.append(" ".join(phoneme_list))
            results["predicted_phoneme_text"] = all_predicted_phoneme_text

            # 5.B BART Text Generation
            if self.bart_text_decoder_active and self.bart_model is not None and self.bart_tokenizer is not None:
                if rank_zero_only.rank == 0: print("Performing BART text generation in predict()...")
                eff_bart_max_gen_len = bart_max_gen_len_predict if bart_max_gen_len_predict is not None \
                                     else self.hparams.get('bart_max_gen_len_predict', 150)
                eff_bart_num_beams = bart_num_beams_predict if bart_num_beams_predict is not None \
                                   else self.hparams.get('bart_num_beams_predict', self.hparams.get('bart_val_num_beams', 4))

                # Project RAW ECoG encoder outputs for BART
                projected_raw_encoder_outputs_for_bart_predict = self.ecog_to_bart_hidden_projection(raw_ecog_encoder_outputs_pred)
                # This mask refers to the padding of our ECoG encoder outputs
                bart_predict_encoder_padding_mask = (~memory_key_padding_mask).int() if memory_key_padding_mask is not None else None

                # Wrap encoder_outputs for generate method
                wrapped_encoder_outputs_predict = BaseModelOutput(last_hidden_state=projected_raw_encoder_outputs_for_bart_predict.contiguous())

                bart_generate_outputs = self.bart_model.generate(
                    encoder_outputs=wrapped_encoder_outputs_predict, # Use wrapped outputs
                    attention_mask=bart_predict_encoder_padding_mask, 
                    max_length=eff_bart_max_gen_len, num_beams=eff_bart_num_beams, early_stopping=True,
                    decoder_start_token_id=self.bart_model.config.decoder_start_token_id,
                    # --- New arguments for attention ---
                    output_attentions=visualize_bart_attention,
                    return_dict_in_generate=True
                )
                generated_bart_ids_predict = bart_generate_outputs.sequences
                results["predicted_text_bart_ids"] = generated_bart_ids_predict.cpu()
                
                # --- Process and store attention if requested ---
                if visualize_bart_attention and 'cross_attentions' in bart_generate_outputs and bart_generate_outputs.cross_attentions is not None:
                    # `cross_attentions` is a tuple of tuples. Outer is over generated tokens, inner is over decoder layers.
                    # Each element is (B, H, 1, T_src). We want the last layer's attention for all tokens.
                    last_layer_attentions = [token_attentions[-1] for token_attentions in bart_generate_outputs.cross_attentions]
                    if last_layer_attentions:
                        # Concatenate along the target sequence length dimension (which is 1 for each token)
                        # The shape becomes (B, H, T_tgt, T_src)
                        bart_attention_matrix = torch.cat(last_layer_attentions, dim=2)
                        results["bart_cross_attention"] = bart_attention_matrix.cpu()

                predicted_bart_text_list = self.bart_tokenizer.batch_decode(
                    generated_bart_ids_predict, skip_special_tokens=True, 
                    clean_up_tokenization_spaces=True
                )
                results["predicted_text_bart"] = predicted_bart_text_list
                if rank_zero_only.rank == 0 and predicted_bart_text_list:
                     print(f"BART predicted text example: {predicted_bart_text_list[0]}")
            
            # 5.C Whisper Text Generation
            if self.whisper_text_decoder_active and self.whisper_model is not None and self.whisper_processor is not None:
                if rank_zero_only.rank == 0: print("Performing Whisper text generation in predict()...")
                eff_whisper_max_gen_len = whisper_max_gen_len_predict if whisper_max_gen_len_predict is not None \
                                     else self.hparams.get('whisper_max_gen_len_predict', 150)
                eff_whisper_num_beams = whisper_num_beams_predict if whisper_num_beams_predict is not None \
                                   else self.hparams.get('whisper_num_beams_predict', self.hparams.get('whisper_val_num_beams', 4))

                projected_raw_encoder_outputs_for_whisper_predict = self.ecog_to_whisper_hidden_projection(raw_ecog_encoder_outputs_pred)
                wrapped_encoder_outputs_predict = BaseModelOutput(last_hidden_state=projected_raw_encoder_outputs_for_whisper_predict.contiguous())
                
                # Create attention mask for prediction
                whisper_predict_encoder_padding_mask = (~memory_key_padding_mask).long() if memory_key_padding_mask is not None else None

                # Get forced decoder IDs for English transcription task
                forced_decoder_ids_predict = self.whisper_processor.get_decoder_prompt_ids(language="en", task="transcribe")

                generated_whisper_ids_predict = self.whisper_model.generate(
                    encoder_outputs=wrapped_encoder_outputs_predict,
                    attention_mask=whisper_predict_encoder_padding_mask, # Add attention mask
                    max_length=eff_whisper_max_gen_len,
                    num_beams=eff_whisper_num_beams,
                    early_stopping=True,
                    forced_decoder_ids=forced_decoder_ids_predict # Use explicit forced decoder IDs
                )
                results["predicted_text_whisper_ids"] = generated_whisper_ids_predict.cpu()

                predicted_whisper_text_list = self.whisper_tokenizer.batch_decode(
                    generated_whisper_ids_predict, skip_special_tokens=True,
                    clean_up_tokenization_spaces=True
                )
                results["predicted_text_whisper"] = predicted_whisper_text_list
                if rank_zero_only.rank == 0 and predicted_whisper_text_list:
                    print(f"Whisper predicted text example: {predicted_whisper_text_list[0]}")
            
        return results


    def setup(self, stage=None):
        # Called by PyTorch Lightning. Good place for initial weight loading.
        if stage == 'fit' or stage is None: # Ensure it runs before training starts
            if self.hparams.get('initial_encoder_checkpoint_path'):
                print(f"Loading initial ENCODER weights from: {self.hparams.initial_encoder_checkpoint_path}")
                ckpt = torch.load(self.hparams.initial_encoder_checkpoint_path, map_location=self.device)
                state_dict = ckpt.get('state_dict', ckpt) # Handle raw state_dict or PL checkpoint
                
                encoder_weights = {}
                # Define prefixes for encoder components
                encoder_prefixes = ['input_projection.', 'day_transform.', 'encoder.']
                # Add aux_heads prefixes if they were part of the encoder checkpoint
                for head_name in self.aux_heads.keys():
                    encoder_prefixes.append(f'aux_heads.{head_name}.')

                for k, v in state_dict.items():
                    for prefix in encoder_prefixes:
                        if k.startswith(prefix):
                            encoder_weights[k] = v
                            break
                if encoder_weights:
                    self.load_state_dict(encoder_weights, strict=False)
                    print(f"Loaded {len(encoder_weights)} keys for encoder components.")
                else:
                    print("WARNING: No encoder weights found or matched in the provided checkpoint.")


            if self.hparams.get('initial_decoder_checkpoint_path'):
                print(f"Loading initial DECODER weights from: {self.hparams.initial_decoder_checkpoint_path}")
                # This checkpoint is assumed to be from StandaloneTextDecoder
                ckpt = torch.load(self.hparams.initial_decoder_checkpoint_path, map_location=self.device)
                state_dict = ckpt.get('state_dict', ckpt)

                decoder_weights = {}
                # Mapping from StandaloneTextDecoder to NeuralToPhonemeTransformer
                # Assumes StandaloneTextDecoder had 'phoneme_embedding', 'decoder', 'output_projection'
                # And their internal structure matches nn.TransformerDecoder etc.
                # Example: if standalone had 'embedding.' use that.
                # This mapping might need adjustment based on actual StandaloneTextDecoder structure.
                key_map = { 
                    'phoneme_embedding.': 'phoneme_embedding.', # if standalone has this exact name
                    'decoder.': 'decoder.',                     # if standalone has this exact name
                    'output_projection.': 'output_projection.'  # if standalone has this exact name
                }
                # If standalone model has different top-level names, adjust key_map source keys
                # e.g., if standalone has self.embedding, self.transformer_decoder, self.fc_out
                # key_map = { 'embedding.': 'phoneme_embedding.', 'transformer_decoder.': 'decoder.', 'fc_out.': 'output_projection.'}


                for k_ckpt, v_ckpt in state_dict.items():
                    for k_standalone_prefix, k_full_model_prefix in key_map.items():
                        if k_ckpt.startswith(k_standalone_prefix):
                            new_key = k_full_model_prefix + k_ckpt[len(k_standalone_prefix):]
                            decoder_weights[new_key] = v_ckpt
                            break
                
                if decoder_weights:
                    self.load_state_dict(decoder_weights, strict=False)
                    print(f"Loaded {len(decoder_weights)} keys for decoder components.")
                else:
                    print("WARNING: No decoder weights found or matched in the provided checkpoint. Check key_map.")
        for name, param in self.named_parameters():
            if "aux_heads.phoneme_len" in name:
                print(name, param.requires_grad)
