import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import RNNTLoss
from edit_distance import SequenceMatcher  # Make sure edit_distance library is installed
from augmentations import GaussianSmoothing
import pytorch_lightning as pl  # Make sure PyTorch Lightning is installed
from dataset import idsToPhonemes, idToPhone
from torch.nn.utils.rnn import pad_sequence


def rnnt_autoregressive_inference(encoder_output, prediction_network, joiner, max_output_length=100, blank=0, max_symbols_per_step=None):
    """
    Greedy autoregressive RNN-T decoding for a single utterance.

    The label history is laid out the same way the joiner sees it during
    training: label slot ``u`` of the joiner output is computed from the
    prediction-network output after the first ``u`` emitted tokens, and slot 0
    is the joiner's all-zeros "no history" slot. So, with ``k`` tokens emitted
    so far, the predictor is run on exactly those ``k`` tokens and slot ``k``
    of the joiner output is read.

    Args:
        encoder_output (Tensor): Encoder output, shape (1, T, enc_dim).
        prediction_network (callable): Maps (1, U) token ids to (1, U, hidden_dim).
        joiner (callable): Maps ((1, 1, enc_dim), (1, U, hidden_dim)) to
            logits of shape (1, 1, U+1, vocab_size).
        max_output_length (int): Stop after this many tokens have been emitted.
        blank (int): Index of the blank token.
        max_symbols_per_step (int or None): Optional cap on non-blank emissions
            per encoder frame; once reached, decoding moves to the next frame.
            ``None`` (default) means no cap.

    Returns:
        List[int]: Decoded token indices (blanks excluded).

    Raises:
        ValueError: If ``encoder_output`` does not have batch size 1.
    """
    if encoder_output.size(0) != 1:
        raise ValueError(
            f"rnnt_autoregressive_inference expects batch size 1, got {encoder_output.size(0)}"
        )

    device = encoder_output.device
    T = encoder_output.size(1)

    decoded_sequence = []
    pred_out = None           # cached predictor output for the current history
    emitted_this_frame = 0
    t = 0                     # encoder time step

    while t < T and len(decoded_sequence) < max_output_length:
        if pred_out is None:
            # The predictor needs at least one input token; with an empty
            # history slot 0 (the zero slot) is read, so this dummy token's
            # output is never used.
            history = decoded_sequence if decoded_sequence else [blank]
            prev_tokens = torch.tensor([history], dtype=torch.long, device=device)
            pred_out = prediction_network(prev_tokens)  # (1, len(history), hidden_dim)

        logits = joiner(encoder_output[:, t:t + 1], pred_out)  # (1, 1, len(history)+1, vocab)

        # Slot k corresponds to "k tokens emitted so far".
        slot = len(decoded_sequence)
        # argmax of logits equals argmax of log_softmax(logits).
        predicted_token = int(torch.argmax(logits[0, 0, slot], dim=-1))

        if predicted_token == blank:
            t += 1
            emitted_this_frame = 0
        else:
            decoded_sequence.append(predicted_token)
            pred_out = None  # history changed; recompute on the next step
            emitted_this_frame += 1
            if max_symbols_per_step is not None and emitted_this_frame >= max_symbols_per_step:
                t += 1
                emitted_this_frame = 0

    return decoded_sequence



def rnnt_greedy_decode(logits, vocab_size, blank=0, logit_lengths=None, target_lengths=None):
    """
    Greedy path through a precomputed RNN-T logit lattice.

    Starting at (t=0, u=0), a blank advances ``t``. Any other argmax token is
    emitted and advances ``u``. The walk stops when ``t`` runs out or ``u``
    would move past the last label slot.

    NOTE: if ``logits`` were produced with the reference targets fed to the
    prediction network, the label axis is teacher-forced. The resulting path
    then conditions on the reference and is usually more optimistic than
    free-running decoding.

    Args:
        logits (Tensor): Joiner output, shape (batch, T, U+1, vocab_size).
        vocab_size (int): Unused; kept for API compatibility.
        blank (int): Index of the blank token.
        logit_lengths (Tensor or sequence, optional): (batch,) number of valid
            encoder frames per sample; later frames are not visited.
            Defaults to T for every sample.
        target_lengths (Tensor or sequence, optional): (batch,) number of valid
            labels U_b per sample; label slots beyond U_b are not visited.
            Defaults to U for every sample.

    Returns:
        Tuple[List[List[int]], Tensor]: Decoded token sequences (blanks
        excluded) for each batch element, and ``log_softmax(logits)``.
    """
    batch_size, T, U_plus_1, _ = logits.size()

    log_probs = torch.log_softmax(logits, dim=-1)
    # One argmax and one device->host copy instead of a .item() sync per step.
    best = log_probs.argmax(dim=-1).cpu()  # (batch, T, U+1)

    decoded_sequences = []
    for b in range(batch_size):
        t_max = T if logit_lengths is None else min(T, int(logit_lengths[b]))
        u_max = U_plus_1 - 1 if target_lengths is None else min(U_plus_1 - 1, int(target_lengths[b]))

        t, u = 0, 0
        sequence = []
        while t < t_max:
            token = int(best[b, t, u])
            if token == blank:
                t += 1  # blank: advance in encoder time
            else:
                sequence.append(token)
                u += 1
                # Slot u_max is the last valid label slot; stop only once we
                # move past it (the original stopped one slot early).
                if u > u_max:
                    break
        decoded_sequences.append(sequence)

    return decoded_sequences, log_probs



class TransformerPredictionNetwork(nn.Module):
    """
    Transformer-based prediction network: processes previously emitted tokens.

    Tokens are embedded, a learned positional embedding is added, and the
    sequence is run through a causally masked ``nn.TransformerDecoder``. There
    is no real memory: cross-attention is fed a single all-zeros vector, so the
    decoder stack acts as a causal self-attention model.

    Args:
        vocab_size (int): Number of token ids (including blank).
        hidden_dim (int): Model width.
        num_layers (int): Number of decoder layers.
        num_heads (int): Attention heads per layer (must divide hidden_dim).
        dropout (float): Dropout inside each decoder layer.
        max_len (int): Longest supported input length, i.e. the size of the
            learned positional table. Defaults to 1000.
    """
    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads, dropout, max_len=1000):
        super().__init__()
        self.max_len = max_len
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
        self.transformer_layers = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=4 * hidden_dim,
                dropout=dropout,
                activation="gelu"
            ),
            num_layers=num_layers
        )
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, y):
        """
        Args:
            y (Tensor): (batch, U) integer token ids.

        Returns:
            Tensor: (batch, U, hidden_dim) outputs; position i depends only on y[:, :i+1].

        Raises:
            ValueError: If U exceeds ``max_len`` or any id is outside the vocabulary.
        """
        batch_size, U = y.shape

        if U > self.max_len:
            raise ValueError(f"Input length {U} exceeds max_len={self.max_len} of the positional table")
        # Explicit check (not assert) so validation survives `python -O`.
        if not ((y >= 0).all() and (y < self.embed.num_embeddings).all()):
            raise ValueError("Invalid token index in input!")

        # Embedding and learned positional encoding
        h = self.embed(y) + self.pos_embedding[:, :U, :]

        # Additive causal mask: 0 on/below the diagonal, -inf above it
        # (prevents attending to future tokens).
        causal_mask = torch.triu(
            torch.full((U, U), float("-inf"), device=h.device, dtype=h.dtype), diagonal=1
        )

        # Single all-zeros memory step, seq-first: (1, batch, hidden_dim).
        dummy_memory = torch.zeros((1, batch_size, self.embed.embedding_dim), device=h.device, dtype=h.dtype)

        # nn.TransformerDecoder is used seq-first here, hence the transposes.
        h = self.transformer_layers(h.transpose(0, 1), memory=dummy_memory, tgt_mask=causal_mask).transpose(0, 1)

        return self.layer_norm(self.proj(h))


class RNNPredictionNetwork(nn.Module):
    """
    LSTM prediction network: processes previously emitted tokens.

    Note: it does NOT prepend a blank token; RNNTJoiner adds the zero
    "no history" slot itself.

    Args:
        vocab_size (int): Number of token ids (including blank).
        hidden_dim (int): Embedding / LSTM / output width.
        num_layers (int): Number of stacked LSTM layers.
        dropout (float): Dropout between LSTM layers (only meaningful when
            num_layers > 1).
    """
    def __init__(self, vocab_size, hidden_dim, num_layers, dropout):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.rnn = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between layers; passing a non-zero
            # value with a single layer has no effect and emits a UserWarning.
            dropout=dropout if num_layers > 1 else 0.0
        )
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, y):
        """
        Args:
            y (Tensor): (batch, U) token ids (targets without a leading blank).

        Returns:
            Tensor: (batch, U, hidden_dim) causal LSTM features.
        """
        y = self.embed(y)            # (batch, U, hidden_dim)
        y, _ = self.rnn(y)           # (batch, U, hidden_dim)
        return self.proj(y)          # (batch, U, hidden_dim)

class RNNTJoiner(nn.Module):
    """
    RNN-T joint network combining encoder and prediction-network outputs.

    The encoder output is projected to ``hidden_dim``. An all-zeros slot is
    prepended along the label axis of the prediction output, so label index 0
    carries no label history. The two are summed over a T x (U+1) grid, passed
    through the activation, and mapped to ``output_dim`` logits.
    """
    def __init__(self, hidden_dim, output_dim, bidirectional=False,activation: str = "relu"):
        """
        hidden_dim: width of the prediction-network output (and of the projected encoder output)
        output_dim: final output dimension (typically n_classes+1, with index 0 reserved for blank)
        bidirectional: if True the encoder output is expected to be 2 * hidden_dim wide
        activation: "relu" or "tanh"; any other value raises ValueError
        """
        super().__init__()
        encoder_width = hidden_dim * 2 if bidirectional else hidden_dim
        self.proj_enc = nn.Linear(encoder_width, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

        for name, factory in (("relu", torch.nn.ReLU), ("tanh", torch.nn.Tanh)):
            if activation == name:
                self.activation = factory()
                break
        else:
            raise ValueError(f"Unsupported activation {activation}")

    def forward(self, enc_out, pred_out):
        """
        enc_out: (batch, T, hidden_dim * 2) if bidirectional; else (batch, T, hidden_dim)
        pred_out: (batch, U, hidden_dim) from the prediction network
        returns: (batch, T, U+1, output_dim) logits
        """
        # (batch, T, hidden_dim) -> (batch, T, 1, hidden_dim)
        enc_grid = self.proj_enc(enc_out).unsqueeze(2)

        # (batch, U, hidden_dim) -> (batch, 1, U, hidden_dim), then prepend
        # the zero "no history" slot -> (batch, 1, U+1, hidden_dim)
        label_grid = pred_out.unsqueeze(1)
        zero_slot = torch.zeros_like(label_grid[:, :, :1, :])
        label_grid = torch.cat((zero_slot, label_grid), dim=2)

        # Broadcast sum over T x (U+1), non-linearity, then vocabulary logits.
        return self.fc(self.activation(enc_grid + label_grid))

class LightningRNNTDecoder(pl.LightningModule):
    """
    RNN-T phoneme decoder for neural recordings, as a PyTorch Lightning module.

    Encoder: optional Gaussian smoothing -> optional per-day affine transform ->
    Softsign -> sliding-window unfolding (``kernelLen`` bins, hop ``strideLen``)
    -> GRU. The GRU states are combined with an LSTM prediction network by
    ``RNNTJoiner`` and trained with torchaudio's ``RNNTLoss``. Token index 0 is
    the RNN-T blank, so the output vocabulary has ``n_classes + 1`` entries.
    """
    def __init__(
        self,
        neural_dim,
        n_classes,
        hidden_dim,
        layer_dim,
        nDays=24,
        dropout=0.1,
        strideLen=4,
        kernelLen=14,
        gaussianSmoothWidth=0,
        bidirectional=False,
        learning_rate=1e-3,
        white_noise_SD=0.01,
        constant_offset_SD=0.01,
        weight_decay=1e-5,
        smoothing=True,
        day_transforms=True,
        mfcc_dim=14,
        normalize_rnnt_loss=True
    ):
        super().__init__()

        self.neural_dim = neural_dim
        self.n_classes = n_classes
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.bidirectional = bidirectional
        self.learning_rate = learning_rate
        self.strideLen = strideLen
        self.kernelLen = kernelLen
        self.white_noise_SD = white_noise_SD
        self.constant_offset_SD = constant_offset_SD
        self.weight_decay = weight_decay
        self.smoothing = smoothing
        self.day_transforms = day_transforms
        self.gaussianSmoothWidth = gaussianSmoothWidth
        self.mfcc_dim = mfcc_dim
        self.normalize_rnnt_loss = normalize_rnnt_loss

        self.inputLayerNonlinearity = nn.Softsign()
        self.unfolder = nn.Unfold((self.kernelLen, 1), dilation=1, padding=0, stride=self.strideLen)
        self.gaussianSmoother = GaussianSmoothing(neural_dim, 20, self.gaussianSmoothWidth, dim=1)

        # Per-session affine input transform (applied when day_transforms=True),
        # initialised to identity weights and zero bias.
        self.dayWeights = nn.Parameter(torch.randn(nDays, neural_dim, neural_dim))
        self.dayBias = nn.Parameter(torch.zeros(nDays, 1, neural_dim))
        for x in range(nDays):
            self.dayWeights.data[x, :, :] = torch.eye(neural_dim)

        # Encoder (GRU) over unfolded windows of kernelLen * neural_dim features.
        self.gru_decoder = nn.GRU(
            neural_dim * self.kernelLen,
            hidden_dim,
            layer_dim,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )
        for name, param in self.gru_decoder.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Prediction network over previously emitted tokens (+1 for blank).
        self.prediction_network = RNNPredictionNetwork(
            vocab_size=n_classes + 1,
            hidden_dim=hidden_dim,
            num_layers=1,
            dropout=dropout
        )

        # Joiner producing (batch, T_eff, U+1, n_classes+1) logits.
        self.joiner = RNNTJoiner(hidden_dim, n_classes + 1, bidirectional=bidirectional)

        # RNN-T loss, blank index 0.
        self.rnnt_loss = RNNTLoss(blank=0, reduction="mean")

        # MFCC / CTC / L1 heads: not used by the RNN-T objective in this class,
        # kept so existing checkpoints (mfcc_decoder holds parameters) still load.
        self.mfcc_unfolder = nn.Unfold((self.strideLen, 1), dilation=1, padding=0, stride=self.strideLen)
        self.mfcc_decoder = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, mfcc_dim * self.strideLen)
        self.ctc_loss = nn.CTCLoss(blank=0, reduction="mean", zero_infinity=True)
        self.l1oss = nn.L1Loss()

    def get_neural_embeddings(self, neuralInput, dayIdx):
        """
        Encode neural features into GRU hidden states.

        Args:
            neuralInput: (batch, time, neural_dim) features.
            dayIdx: (batch,) integer session indices into dayWeights/dayBias
                (only read when ``day_transforms`` is True).

        Returns:
            Tensor of shape (batch, L, hidden_dim * (2 if bidirectional else 1)),
            where L = floor((time - kernelLen) / strideLen) + 1 unfolded windows.
        """
        if self.smoothing:
            # The smoother is fed (batch, neural_dim, time); presumably it smooths
            # along time -- confirm in augmentations.GaussianSmoothing.
            neuralInput = torch.permute(neuralInput, (0, 2, 1))
            neuralInput = self.gaussianSmoother(neuralInput)
            neuralInput = torch.permute(neuralInput, (0, 2, 1))

        if self.day_transforms:
            # Per-sample affine transform selected by session index.
            dayWeights = torch.index_select(self.dayWeights, 0, dayIdx)
            transformedNeural = torch.einsum("btd,bdk->btk", neuralInput, dayWeights) + torch.index_select(self.dayBias, 0, dayIdx)
            transformedNeural = self.inputLayerNonlinearity(transformedNeural)
        else:
            transformedNeural = self.inputLayerNonlinearity(neuralInput)

        # Sliding windows over time: (batch, L, neural_dim * kernelLen).
        stridedInputs = torch.permute(
            self.unfolder(torch.unsqueeze(torch.permute(transformedNeural, (0, 2, 1)), 3)), (0, 2, 1)
        )

        # Zero initial hidden state (no gradient needed).
        h0 = torch.zeros(
            self.layer_dim * (2 if self.bidirectional else 1),
            stridedInputs.size(0),
            self.hidden_dim,
            device=stridedInputs.device,
            dtype=stridedInputs.dtype,
        )

        hid, _ = self.gru_decoder(stridedInputs, h0)
        return hid

    def forward(self, neuralInput, dayIdx, target):
        """
        neuralInput: (batch, T, neural_dim)
        dayIdx: (batch,) session indices
        target: (batch, U) token indices (true lengths are tracked separately)

        Returns (batch, T_eff, U+1, n_classes+1) joiner logits.
        """
        enc_out = self.get_neural_embeddings(neuralInput, dayIdx)  # (batch, T_eff, hidden_dim * dirs)
        pred_out = self.prediction_network(target)                 # (batch, U, hidden_dim)
        return self.joiner(enc_out, pred_out)

    def _prepare_batch(self, batch):
        """Move the tensors used by the RNN-T objective to this module's device.

        Returns (X, y, X_len, y_len, dayIdx, y_trim). ``y_trim`` is ``y`` cut to
        the longest true target length in the batch; torchaudio's RNN-T loss
        requires logits.size(2) == max(y_len) + 1.
        """
        X = batch["neural_feats"].to(self.device)
        y = batch["phone_seq"].to(self.device)
        X_len = batch["neural_time_bins"].to(self.device)
        y_len = batch["phone_seq_len"].to(self.device)
        dayIdx = batch["day"].to(self.device)

        max_target = int(y_len.max().item())
        y_trim = y[:, :max_target].contiguous()
        return X, y, X_len, y_len, dayIdx, y_trim

    def _compute_rnnt_loss(self, pred, y_trim, X_len, y_len):
        """RNN-T loss for joiner logits ``pred`` of shape (batch, T_eff, U+1, vocab)."""
        # nn.Unfold yields floor((T - kernelLen) / strideLen) + 1 windows per sample.
        effective_T = ((X_len - self.kernelLen) // self.strideLen) + 1

        # torchaudio's RNN-T loss requires int32 targets and lengths.
        loss = self.rnnt_loss(
            pred.contiguous(),
            y_trim.to(torch.int32),
            effective_T.to(torch.int32),
            y_len.to(torch.int32),
        )
        if self.normalize_rnnt_loss:
            # NOTE: reduction="mean" already averages over the batch, so this
            # divides by the batch size a second time. Kept for compatibility
            # with existing runs / logged loss scales.
            loss = loss / len(pred)
        return loss

    def training_step(self, batch, batch_idx):
        """
        Training step: RNN-T loss with noise augmentation.

        The batch dict is read for "neural_feats" (batch, T, neural_dim),
        "phone_seq" (batch, U), "neural_time_bins" (input lengths before
        unfolding), "phone_seq_len" (true target lengths) and "day".
        """
        X, y, X_len, y_len, dayIdx, y_trim = self._prepare_batch(batch)

        # Out-of-place noise so the tensor stored in `batch` is not mutated.
        if self.white_noise_SD > 0:
            X = X + torch.randn_like(X) * self.white_noise_SD
        if self.constant_offset_SD > 0:
            X = X + torch.randn(X.size(0), 1, X.size(2), device=self.device) * self.constant_offset_SD

        pred = self.forward(X, dayIdx, y_trim)
        loss = self._compute_rnnt_loss(pred, y_trim, X_len, y_len)

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step: RNN-T loss plus a greedy-decoding phoneme error rate.

        NOTE: ``forward`` feeds the reference targets to the prediction network,
        so ``rnnt_greedy_decode`` walks a teacher-forced lattice. The logged
        ``val_CER`` is therefore optimistic compared with free-running decoding
        via ``infer``.
        """
        X, y, X_len, y_len, dayIdx, y_trim = self._prepare_batch(batch)

        pred = self.forward(X, dayIdx, y_trim)
        loss = self._compute_rnnt_loss(pred, y_trim, X_len, y_len)

        decoded, _ = rnnt_greedy_decode(pred, self.n_classes + 1, blank=0)

        # Drop blanks (0) and ids beyond the phoneme table before mapping names.
        max_token = len(idToPhone) + 1
        total_edit_distance, total_seq_length = 0, 0
        for i, raw_seq in enumerate(decoded):
            decoded_seq = [int(tok) for tok in raw_seq if int(tok) != 0]
            decoded_seq = [tok for tok in decoded_seq if tok < max_token]

            true_seq = y[i, : int(y_len[i])].cpu().tolist()

            matcher = SequenceMatcher(a=true_seq, b=decoded_seq)
            total_edit_distance += matcher.distance()
            total_seq_length += len(true_seq)

            if i == 0:
                # Only print the first sequence of each batch for debugging.
                print("True:", idsToPhonemes(true_seq))
                print("Predicted:", idsToPhonemes(decoded_seq))

        cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_CER", cer, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, eps=1e-8)
        return optimizer

    def infer(self, neuralInput, dayIdx, max_output_length=100):
        """
        Perform autoregressive greedy inference on neural data.

        Switches the module to eval mode (and leaves it there).

        Args:
            neuralInput (Tensor): Neural data of shape (1, T, neural_dim).
            dayIdx (Tensor): Day index for session-specific transformations.
            max_output_length (int): Maximum output sequence length.

        Returns:
            List[int]: Predicted phoneme sequence.
        """
        self.eval()

        with torch.no_grad():
            encoder_output = self.get_neural_embeddings(neuralInput, dayIdx)
            decoded_sequence = rnnt_autoregressive_inference(
                encoder_output,
                self.prediction_network,
                self.joiner,
                max_output_length=max_output_length,
                blank=0  # index of the blank token
            )

        return decoded_sequence
    














############### PRETRAINED ###########



class LightningPretrainedRNNTDecoder(pl.LightningModule):
    """
    RNN-T decoder trained on top of an existing (optionally frozen) neural encoder.

    ``encoder`` provides the neural embeddings. Its ``kernelLen``,
    ``strideLen``, ``gaussianSmoothWidth``, ``bidirectional``, ``layer_dim``,
    ``hidden_dim`` and ``mfcc_dim`` attributes are read at construction. A new
    prediction network ("rnn" LSTM or "transformer") and ``RNNTJoiner`` are
    trained with torchaudio's ``RNNTLoss`` (blank index 0).

    NOTE: the joiner is built from this module's ``hidden_dim`` and
    ``bidirectional`` arguments. ``hidden_dim * (2 if bidirectional else 1)``
    must therefore equal the encoder's output width.
    """
    def __init__(
        self,
        neural_dim,
        n_classes,
        hidden_dim,
        dropout,
        bidirectional,
        encoder,
        learning_rate=1e-3,
        white_noise_SD=0.01,
        constant_offset_SD=0.01,
        weight_decay=1e-5,
        smoothing=True,
        day_transforms=True,
        num_LM_layers=2,
        prediction_type="rnn",  # "rnn" or "transformer"
        freeze_encoder=True,
        normalize_rnnt_loss=True
    ):
        super().__init__()

        self.encoder = encoder
        self.normalize_rnnt_loss = normalize_rnnt_loss

        # Windowing/layout attributes mirrored from the encoder.
        self.kernelLen = self.encoder.kernelLen
        self.strideLen = self.encoder.strideLen
        self.gaussianSmoothWidth = self.encoder.gaussianSmoothWidth
        self.bidirectional = self.encoder.bidirectional
        self.layer_dim = self.encoder.layer_dim

        self.neural_dim = neural_dim
        self.n_classes = n_classes
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
        self.white_noise_SD = white_noise_SD
        self.constant_offset_SD = constant_offset_SD
        self.weight_decay = weight_decay
        self.smoothing = smoothing
        self.day_transforms = day_transforms
        self.freeze_encoder = freeze_encoder

        # These input-pipeline modules are not used here (the encoder does its
        # own preprocessing); kept so existing checkpoints still load.
        self.inputLayerNonlinearity = nn.Softsign()
        self.unfolder = nn.Unfold((self.kernelLen, 1), dilation=1, padding=0, stride=self.strideLen)
        self.gaussianSmoother = GaussianSmoothing(neural_dim, 20, self.gaussianSmoothWidth, dim=1)
        self.prediction_type = prediction_type
        self.num_LM_layers = num_LM_layers
        self.mfcc_decoder = nn.Linear(self.encoder.hidden_dim * 2 if self.encoder.bidirectional else self.encoder.hidden_dim, self.encoder.mfcc_dim * self.strideLen)

        if self.freeze_encoder:
            # Only gradients are disabled; the encoder is still a submodule, so
            # train()/eval() on this module also toggles its dropout.
            for param in self.encoder.parameters():
                param.requires_grad = False

        if self.prediction_type == "transformer":
            self.prediction_network = TransformerPredictionNetwork(
                vocab_size=n_classes + 1,  # +1 for blank
                hidden_dim=hidden_dim,
                num_layers=num_LM_layers,
                num_heads=4,
                dropout=dropout)
        elif self.prediction_type == "rnn":
            self.prediction_network = RNNPredictionNetwork(
                vocab_size=n_classes + 1,  # +1 for blank
                hidden_dim=hidden_dim,
                num_layers=num_LM_layers,
                dropout=dropout
            )
        else:
            raise ValueError(
                f"Unknown prediction_type {prediction_type!r}; expected 'rnn' or 'transformer'"
            )

        # Joiner producing (batch, T_eff, U+1, n_classes+1) logits.
        self.joiner = RNNTJoiner(hidden_dim, n_classes + 1, bidirectional=bidirectional)

        # RNN-T loss, blank index 0.
        self.rnnt_loss = RNNTLoss(blank=0, reduction="mean")

    def get_neural_embeddings(self, neuralInput, dayIdx):
        """
        Encode neural input with the wrapped encoder.

        Uses the encoder's ``get_neural_embedding`` if it has one, otherwise
        ``get_neural_embeddings`` (the name LightningRNNTDecoder in this file
        defines).
        """
        embed = getattr(self.encoder, "get_neural_embedding", None)
        if embed is None:
            embed = self.encoder.get_neural_embeddings
        return embed(neuralInput, dayIdx)

    def forward(self, neuralInput, dayIdx, target):
        """
        neuralInput: (batch, T, neural_dim)
        dayIdx: (batch,) session indices
        target: (batch, U) token indices (true lengths are tracked separately)

        Returns (batch, T_eff, U+1, n_classes+1) joiner logits.
        """
        enc_out = self.get_neural_embeddings(neuralInput, dayIdx)
        pred_out = self.prediction_network(target)  # (batch, U, hidden_dim)
        return self.joiner(enc_out, pred_out)

    def _prepare_batch(self, batch):
        """Move the tensors used by the RNN-T objective to this module's device.

        Returns (X, y, X_len, y_len, dayIdx, y_trim). ``y_trim`` is ``y`` cut to
        the longest true target length in the batch; torchaudio's RNN-T loss
        requires logits.size(2) == max(y_len) + 1.
        """
        X = batch["neural_feats"].to(self.device)
        y = batch["phone_seq"].to(self.device)
        X_len = batch["neural_time_bins"].to(self.device)
        y_len = batch["phone_seq_len"].to(self.device)
        dayIdx = batch["day"].to(self.device)

        max_target = int(y_len.max().item())
        y_trim = y[:, :max_target].contiguous()
        return X, y, X_len, y_len, dayIdx, y_trim

    def _compute_rnnt_loss(self, pred, y_trim, X_len, y_len):
        """RNN-T loss for joiner logits ``pred`` of shape (batch, T_eff, U+1, vocab)."""
        # Unfolding yields floor((T - kernelLen) / strideLen) + 1 windows per sample.
        effective_T = ((X_len - self.kernelLen) // self.strideLen) + 1

        # torchaudio's RNN-T loss requires int32 targets and lengths.
        loss = self.rnnt_loss(
            pred.contiguous(),
            y_trim.to(torch.int32),
            effective_T.to(torch.int32),
            y_len.to(torch.int32),
        )
        if self.normalize_rnnt_loss:
            # NOTE: reduction="mean" already averages over the batch, so this
            # divides by the batch size a second time. Kept for compatibility
            # with existing runs / logged loss scales.
            loss = loss / len(pred)
        return loss

    def training_step(self, batch, batch_idx):
        """
        Training step: RNN-T loss with noise augmentation.

        The batch dict is read for "neural_feats" (batch, T, neural_dim),
        "phone_seq" (batch, U), "neural_time_bins" (input lengths before
        unfolding), "phone_seq_len" (true target lengths) and "day".
        """
        X, y, X_len, y_len, dayIdx, y_trim = self._prepare_batch(batch)

        # Out-of-place noise so the tensor stored in `batch` is not mutated.
        if self.white_noise_SD > 0:
            X = X + torch.randn_like(X) * self.white_noise_SD
        if self.constant_offset_SD > 0:
            X = X + torch.randn(X.size(0), 1, X.size(2), device=self.device) * self.constant_offset_SD

        pred = self.forward(X, dayIdx, y_trim)
        loss = self._compute_rnnt_loss(pred, y_trim, X_len, y_len)

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step: RNN-T loss plus a greedy-decoding phoneme error rate.

        NOTE: ``forward`` feeds the reference targets to the prediction network,
        so ``rnnt_greedy_decode`` walks a teacher-forced lattice. The logged
        ``val_CER`` is therefore optimistic compared with free-running decoding
        via ``infer``.
        """
        X, y, X_len, y_len, dayIdx, y_trim = self._prepare_batch(batch)

        pred = self.forward(X, dayIdx, y_trim)
        loss = self._compute_rnnt_loss(pred, y_trim, X_len, y_len)

        decoded, _ = rnnt_greedy_decode(pred, self.n_classes + 1, blank=0)

        # Drop blanks (0) and ids beyond the phoneme table before mapping names.
        max_token = len(idToPhone) + 1
        total_edit_distance, total_seq_length = 0, 0
        for i, raw_seq in enumerate(decoded):
            decoded_seq = [int(tok) for tok in raw_seq if int(tok) != 0]
            decoded_seq = [tok for tok in decoded_seq if tok < max_token]

            true_seq = y[i, : int(y_len[i])].cpu().tolist()

            matcher = SequenceMatcher(a=true_seq, b=decoded_seq)
            total_edit_distance += matcher.distance()
            total_seq_length += len(true_seq)

            if i == 0:
                # Only print the first sequence of each batch for debugging.
                print("True:", idsToPhonemes(true_seq))
                print("Predicted:", idsToPhonemes(decoded_seq))

        cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_CER", cer, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, eps=1e-8)
        return optimizer

    def infer(self, neuralInput, dayIdx, max_output_length=100):
        """
        Perform autoregressive greedy inference on neural data.

        Switches the module to eval mode (and leaves it there).

        Args:
            neuralInput (Tensor): Neural data of shape (1, T, neural_dim).
            dayIdx (Tensor): Day index for session-specific transformations.
            max_output_length (int): Maximum output sequence length.

        Returns:
            List[int]: Predicted phoneme sequence.
        """
        self.eval()

        with torch.no_grad():
            encoder_output = self.get_neural_embeddings(neuralInput, dayIdx)
            decoded_sequence = rnnt_autoregressive_inference(
                encoder_output,
                self.prediction_network,
                self.joiner,
                max_output_length=max_output_length,
                blank=0  # index of the blank token
            )

        return decoded_sequence