import torch
from torch import nn

from augmentations import GaussianSmoothing
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import ReduceLROnPlateau
from edit_distance import SequenceMatcher
import math
from transformers import AutoProcessor, ClapModel, AutoModel, AutoTokenizer
import numpy as np
from collections import Counter
import torch.nn as nn
import torch.nn.functional as F
from transformers import MambaConfig, MambaModel



def random_monotonic_warp_curve(knot: int = 4, sigma: float = 0.2, device='cpu'):
    """
    Generate a random, strictly increasing piecewise-linear warp curve f: [0,1]->[0,1].

    The curve is described by `knot` control points. Their x-positions are evenly
    spaced on [0, 1]; their y-values start at 0, end at 1, and are built from
    per-segment slopes of 1/(knot-1) perturbed by uniform noise in
    [-sigma/2, +sigma/2].

    Parameters
    ----------
    knot : int
        Number of control points, including both end points. Must be >= 2.
    sigma : float
        Width of the uniform noise added to each segment increment.
    device : str or torch.device
        Device on which the returned tensors are created.

    Returns
    -------
    control_x : torch.Tensor of shape (knot,)  in ascending order from 0 to 1
    warp_curve : torch.Tensor of shape (knot,) in ascending order from 0 to 1

    Raises
    ------
    ValueError
        If `knot` is smaller than 2 (a curve needs at least its two end points).
    """
    if knot < 2:
        raise ValueError(f"knot must be >= 2 to define a warp curve, got {knot}")

    n_segments = knot - 1

    # Evenly spaced control positions in [0, 1].
    control_x = torch.linspace(0, 1, steps=knot, device=device)

    # Baseline increments that sum to 1, plus zero-centred uniform noise.
    increments = torch.full((n_segments,), fill_value=1.0 / n_segments, device=device)
    increments += sigma * (torch.rand(n_segments, device=device) - 0.5)

    # Clamp so every segment keeps a positive slope (guarantees monotonicity).
    increments = torch.clamp(increments, min=1e-4)

    # Integrate the increments to obtain the curve, starting at 0.
    warp_curve = torch.cat([torch.tensor([0.0], device=device), torch.cumsum(increments, dim=0)], dim=0)

    # Normalise so the curve ends at 1. After the clamp the last value is at
    # least 1e-4, so the division is always well defined.
    warp_curve = warp_curve / warp_curve[-1]

    # Pin the end points exactly, independent of floating-point rounding.
    warp_curve[0] = 0.0
    warp_curve[-1] = 1.0

    return control_x, warp_curve


def time_warp_variable_length(
    X: torch.Tensor,
    new_seq_len: int,
    knot: int = 4,
    sigma: float = 0.2
) -> torch.Tensor:
    """
    Perform random nonlinear time warping on each sequence in X,
    changing the sequence length to `new_seq_len`.

    For every sample a fresh monotonic warp curve is drawn with
    `random_monotonic_warp_curve`; the curve is inverted by linear
    interpolation to map each output time step back to a (fractional) input
    time index, and each channel is then linearly interpolated at those
    indices. The interpolation runs on the CPU with NumPy and the result is
    moved back to X's device; the output is detached from autograd.

    Parameters
    ----------
    X : torch.Tensor
        Shape (BS, seq_len, channel_dim).
    new_seq_len : int
        Desired output sequence length.
    knot : int
        Number of control points for the warp curve. More -> more wiggles.
    sigma : float
        Noise amplitude for random slope offsets.

    Returns
    -------
    warped_X : torch.Tensor
        Shape (BS, new_seq_len, channel_dim), same dtype and device as X.
    """
    assert X.ndim == 3, "X must have shape (BS, seq_len, channel_dim)."
    BS, seq_len, channel_dim = X.shape
    device = X.device

    # Loop-invariant data: do the device->host copy and build the time grids
    # once, rather than once per sample / per channel.
    X_np = X.detach().cpu().numpy()                       # (BS, seq_len, channel_dim)
    old_times = np.arange(seq_len, dtype=np.float32)      # input sample positions 0..seq_len-1
    t_new_arr = np.linspace(0, 1, new_seq_len, endpoint=True)  # normalised output times

    warped_np = np.zeros((BS, new_seq_len, channel_dim), dtype=X_np.dtype)

    for b in range(BS):
        # Random monotonic warp curve for this sample: warp_curve = f(control_x).
        control_x, warp_curve = random_monotonic_warp_curve(knot, sigma, device=device)

        # Invert the piecewise-linear curve: for each normalised output time,
        # find t_old such that f(t_old) = t_new. Swapping the roles of x and y
        # in np.interp is valid because warp_curve is increasing.
        t_old_arr = np.interp(
            t_new_arr,
            warp_curve.detach().cpu().numpy(),
            control_x.detach().cpu().numpy(),
        )

        # Map normalised time [0, 1] to fractional input indices [0, seq_len-1].
        old_idx_arr = t_old_arr * (seq_len - 1)

        # np.interp is 1-D only, so resample channel by channel.
        for c in range(channel_dim):
            warped_np[b, :, c] = np.interp(x=old_idx_arr, xp=old_times, fp=X_np[b, :, c])

    return torch.from_numpy(warped_np).to(device)

class ContrastiveLoss(nn.Module):
    """
    Symmetric cross-entropy loss over the pairwise similarity matrix of two
    batches of embeddings, where row k of `z_i` is paired with row k of `z_j`.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        # Similarities are divided by this value before the softmax.
        self.temperature = temperature

    def forward(self, z_i, z_j, normalize=True):
        """
        Compute the loss for two equally shaped embedding batches.

        z_i, z_j : tensors of identical shape; dim 0 indexes the pairs.
        normalize : when True, both inputs are L2-normalised along dim 1 so
            the dot products below are cosine similarities.

        Returns the mean of the i->j and j->i cross-entropy terms.
        """
        if normalize:
            z_i, z_j = F.normalize(z_i, dim=1), F.normalize(z_j, dim=1)

        assert z_i.shape == z_j.shape, f"z_i shape {z_i.shape}, z_j shape {z_j.shape}"

        n_pairs = z_i.size(0)

        # (N, N) similarity matrix; entry [a, b] compares z_i[a] with z_j[b].
        similarity = torch.matmul(z_i, z_j.T) / self.temperature
        # The matching partner of row k sits on the diagonal, i.e. class k.
        targets = torch.arange(n_pairs, device=z_i.device)

        assert similarity.shape == (z_i.shape[0], z_i.shape[0]), f"logits shape {similarity.shape}, expected ({z_i.shape[0]}, {z_i.shape[0]})"

        # Rows classify z_i against all z_j; the transpose does the reverse.
        i_to_j = F.cross_entropy(similarity, targets)
        j_to_i = F.cross_entropy(similarity.T, targets)
        return 0.5 * (i_to_j + j_to_i)


class WarmupCosineAnnealingLR(torch.optim.lr_scheduler.LambdaLR):
    """LR scheduler with linear warmup followed by cosine annealing."""

    def __init__(self, optimizer, warmup_steps, total_steps, eta_min=0):
        """
        Custom LR scheduler with linear warmup followed by cosine annealing.

        Args:
            optimizer (torch.optim.Optimizer): Wrapped optimizer.
            warmup_steps (int): Number of steps for linear warmup.
            total_steps (int): Total number of training steps; the cosine
                decay ends here and the factor stays at its floor afterwards.
            eta_min (float): Floor of the LR *multiplier* at the end of decay.
                LambdaLR multiplies each group's initial LR by the value of
                `lr_lambda`, so the final LR is `eta_min * initial_lr`, not
                `eta_min` itself.
        """
        # These must exist before the parent constructor runs, because the
        # parent receives (and may immediately evaluate) self.lr_lambda.
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_min = eta_min

        super().__init__(optimizer, self.lr_lambda)

    def lr_lambda(self, step):
        """
        Return the multiplicative LR factor for `step`.

        The factor rises linearly from 0 to 1 over the warmup steps, then
        follows a half cosine from 1 down to `eta_min`, and stays at
        `eta_min` for any step at or beyond `total_steps`.
        """
        if step < self.warmup_steps:
            return step / self.warmup_steps  # Linear warmup

        decay_span = self.total_steps - self.warmup_steps
        if decay_span <= 0:
            # No room for a decay phase: treat the decay as already finished
            # instead of dividing by zero.
            decay_ratio = 1.0
        else:
            # Clamp to 1 so the cosine does not swing back up once training
            # runs past total_steps.
            decay_ratio = min((step - self.warmup_steps) / decay_span, 1.0)

        cosine_decay = 0.5 * (1 + math.cos(math.pi * decay_ratio))
        return cosine_decay * (1 - self.eta_min) + self.eta_min


class SimpleGRUDecoder(pl.LightningModule):
    """
    Conv1d -> BatchNorm1d -> ReLU -> GRU -> Linear decoder trained with CTC loss.

    Batches are expected to unpack into six items
    (X, y, X_len, y_len, days, sentence); only the first four are used here.
    """

    def __init__(
        self,
        neural_dim,      # Input feature dimension
        n_classes,       # Number of phoneme classes (+1 for CTC blank)
        hidden_dim=128,  # Hidden size of GRU
        layer_dim=2,     # Number of GRU layers
        dropout=0.1,     # Dropout rate
        bidirectional=True,
        learning_rate=1e-3,
        weight_decay=0,
    ):
        super().__init__()

        # Optimiser hyper-parameters, read back in configure_optimizers.
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.neural_dim = neural_dim
        self.n_classes = n_classes

        # Convolutional front end; kernel 5 with padding 2 and stride 1 keeps
        # the number of time steps unchanged.
        self.conv1d = nn.Conv1d(neural_dim, hidden_dim, kernel_size=5, stride=1, padding=2)
        self.batch_norm = nn.BatchNorm1d(hidden_dim)
        self.activation = nn.ReLU()

        # Recurrent core.
        self.gru = nn.GRU(
            hidden_dim,
            hidden_dim,
            num_layers=layer_dim,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        # A bidirectional GRU concatenates both directions, doubling the width.
        n_directions = 2 if bidirectional else 1
        gru_output_dim = hidden_dim * n_directions
        self.fc = nn.Linear(gru_output_dim, n_classes + 1)  # +1 for CTC blank token

        # CTC loss with index 0 reserved for the blank symbol.
        self.ctc_loss = nn.CTCLoss(blank=0, reduction="mean", zero_infinity=True)

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, time_steps, neural_dim)

        Returns per-frame logits of shape (batch_size, time_steps, n_classes + 1).
        """
        # Conv1d / BatchNorm1d want channels before time.
        features = x.transpose(1, 2)                      # (batch, features, time)
        features = self.activation(self.batch_norm(self.conv1d(features)))
        features = features.transpose(1, 2)               # (batch, time, features)

        recurrent_out, _ = self.gru(features)
        return self.fc(recurrent_out)

    def _ctc_loss_for_batch(self, batch):
        """Run the forward pass on one batch and return its CTC loss."""
        X, y, X_len, y_len, days, sentence = batch
        X = X.to(self.device)
        y = y.to(self.device)
        X_len = X_len.to(self.device)
        y_len = y_len.to(self.device)

        logits = self.forward(X)  # (batch, time, n_classes + 1)
        log_probs = logits.log_softmax(dim=-1)

        # nn.CTCLoss expects (time, batch, classes).
        return self.ctc_loss(log_probs.permute(1, 0, 2), y, X_len, y_len)

    def training_step(self, batch, batch_idx):
        """Compute and log the CTC loss for one training batch."""
        loss = self._ctc_loss_for_batch(batch)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the CTC loss for one validation batch."""
        loss = self._ctc_loss_for_batch(batch)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Adam plus a ReduceLROnPlateau scheduler that monitors `val_loss`."""
        # NOTE(review): eps=0.1 is far above Adam's default of 1e-8 - confirm
        # this is intentional.
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.999),
            eps=0.1,
        )
        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}



class GRUDecoder(nn.Module):
    """
    GRU decoder with a per-day linear input transform and sliding-window
    (unfold) framing, producing per-frame logits over n_classes + 1 symbols
    (the extra one is the CTC blank).

    Pipeline in `forward`: Gaussian smoothing -> per-day affine transform ->
    Softsign -> unfold into overlapping windows of `kernelLen` time steps
    taken every `strideLen` steps -> GRU -> linear read-out.
    """

    def __init__(
        self,
        neural_dim,
        n_classes,
        hidden_dim,
        layer_dim,
        nDays=24,
        dropout=0,
        device="cuda",
        strideLen=4,
        kernelLen=14,
        gaussianSmoothWidth=0,
        bidirectional=False,
    ):
        """
        Args:
            neural_dim: Number of input features per time step.
            n_classes: Number of output classes, excluding the CTC blank.
            hidden_dim: GRU hidden size.
            layer_dim: Number of stacked GRU layers.
            nDays: Number of distinct day indices with their own transform.
            dropout: Dropout between GRU layers.
            device: Stored on the instance for backward compatibility;
                `forward` no longer uses it to allocate tensors.
            strideLen: Step between consecutive unfold windows.
            kernelLen: Number of time steps per unfold window.
            gaussianSmoothWidth: Passed to the project's GaussianSmoothing.
            bidirectional: Whether the GRU is bidirectional.
        """
        super().__init__()

        # Defining the number of layers and the nodes in each layer
        self.layer_dim = layer_dim
        self.hidden_dim = hidden_dim
        self.neural_dim = neural_dim
        self.n_classes = n_classes
        self.nDays = nDays
        self.device = device
        self.dropout = dropout
        self.strideLen = strideLen
        self.kernelLen = kernelLen
        self.gaussianSmoothWidth = gaussianSmoothWidth
        self.bidirectional = bidirectional

        self.inputLayerNonlinearity = torch.nn.Softsign()
        self.unfolder = torch.nn.Unfold(
            (self.kernelLen, 1), dilation=1, padding=0, stride=self.strideLen
        )
        # Project-local helper; its positional arguments are presumably
        # (channels, kernel_size, sigma) - confirm against `augmentations`.
        self.gaussianSmoother = GaussianSmoothing(
            neural_dim, 20, self.gaussianSmoothWidth, dim=1
        )

        # Per-day affine transform, initialised to the identity map.
        # torch.randn is kept (although overwritten) so the global RNG stream,
        # and therefore seeded initialisation of the layers below, is unchanged.
        self.dayWeights = torch.nn.Parameter(torch.randn(nDays, neural_dim, neural_dim))
        self.dayBias = torch.nn.Parameter(torch.zeros(nDays, 1, neural_dim))
        with torch.no_grad():
            self.dayWeights.copy_(torch.eye(neural_dim).expand(nDays, -1, -1))

        # GRU layers; each input frame is one flattened unfold window.
        self.gru_decoder = nn.GRU(
            (neural_dim) * self.kernelLen,
            hidden_dim,
            layer_dim,
            batch_first=True,
            dropout=self.dropout,
            bidirectional=self.bidirectional,
        )

        for name, param in self.gru_decoder.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Per-day input layers (default Linear init plus identity). They are
        # not used by `forward` in this class but are kept so the module's
        # parameter set and state_dict keys stay the same.
        for day in range(nDays):
            day_layer = nn.Linear(neural_dim, neural_dim)
            with torch.no_grad():
                day_layer.weight.add_(torch.eye(neural_dim))
            setattr(self, "inpLayer" + str(day), day_layer)

        # rnn outputs; a bidirectional GRU doubles the feature width.
        rnn_out_dim = hidden_dim * 2 if self.bidirectional else hidden_dim
        self.fc_decoder_out = nn.Linear(rnn_out_dim, n_classes + 1)  # +1 for CTC blank

    def forward(self, neuralInput, dayIdx):
        """
        Args:
            neuralInput: Tensor indexed (batch, time, feature), as fixed by the
                einsum below.
            dayIdx: 1-D index tensor with one day index per batch element
                (used with torch.index_select on dim 0).

        Returns:
            Logits of shape (batch, n_windows, n_classes + 1), where n_windows
            is the number of unfold windows.
        """
        # The smoother is applied with time as the last axis.
        neuralInput = torch.permute(neuralInput, (0, 2, 1))
        neuralInput = self.gaussianSmoother(neuralInput)
        neuralInput = torch.permute(neuralInput, (0, 2, 1))

        # apply day layer: x @ W_day + b_day for each sample's own day
        dayWeights = torch.index_select(self.dayWeights, 0, dayIdx)
        dayBias = torch.index_select(self.dayBias, 0, dayIdx)
        transformedNeural = torch.einsum("btd,bdk->btk", neuralInput, dayWeights) + dayBias
        transformedNeural = self.inputLayerNonlinearity(transformedNeural)

        # stride/kernel: (B, T, D) -> (B, D, T, 1) -> unfold -> (B, windows, D * kernelLen)
        stridedInputs = torch.permute(
            self.unfolder(
                torch.unsqueeze(torch.permute(transformedNeural, (0, 2, 1)), 3)
            ),
            (0, 2, 1),
        )

        # apply RNN layer. With no initial state given, nn.GRU starts from
        # zeros on the input's own device and dtype - the same values the old
        # explicit h0 had, without depending on the stored `self.device` string.
        hid, _ = self.gru_decoder(stridedInputs)

        # get seq
        seq_out = self.fc_decoder_out(hid)
        return seq_out



##LIGHTNING 

class LightningGRUDecoder(pl.LightningModule):
    """
    PyTorch Lightning GRU decoder trained with CTC loss.

    Pipeline: optional Gaussian smoothing -> optional per-day affine transform
    -> Softsign -> unfold into windows of `kernelLen` time steps every
    `strideLen` steps -> GRU -> linear read-out over n_classes + 1 symbols
    (index 0 is the CTC blank).

    Batches are expected to unpack into six items
    (X, y, X_len, y_len, dayIdx, sentence); `sentence` is not used here.
    """

    def __init__(
        self,
        neural_dim,
        n_classes,
        hidden_dim,
        layer_dim,
        nDays=24,
        dropout=0.1,
        strideLen=4,
        kernelLen=14,
        gaussianSmoothWidth=0,
        bidirectional=False,
        learning_rate=1e-3,
        white_noise_SD=0.01,
        constant_offset_SD=0.01,
        weight_decay=1e-5,
        smoothing=True,
        day_transforms=True,
        warmup_steps=1000,
        total_steps=10000,
    ):
        """
        Args:
            neural_dim: Number of input features per time step.
            n_classes: Number of output classes, excluding the CTC blank.
            hidden_dim: GRU hidden size.
            layer_dim: Number of stacked GRU layers.
            nDays: Number of distinct day indices with their own transform.
            dropout: Dropout between GRU layers.
            strideLen: Step between consecutive unfold windows.
            kernelLen: Number of time steps per unfold window.
            gaussianSmoothWidth: Passed to the project's GaussianSmoothing.
            bidirectional: Whether the GRU is bidirectional.
            learning_rate: Adam learning rate.
            white_noise_SD: Std of per-element Gaussian noise added in training.
            constant_offset_SD: Std of per-(sample, feature) Gaussian offset
                added in training (constant over time).
            weight_decay: Adam weight decay.
            smoothing: Apply Gaussian smoothing to the input.
            day_transforms: Apply the per-day affine transform.
            warmup_steps: Stored only; not used by the active scheduler.
            total_steps: Stored only; not used by the active scheduler.
        """
        super().__init__()

        self.layer_dim = layer_dim
        self.hidden_dim = hidden_dim
        self.neural_dim = neural_dim
        self.n_classes = n_classes
        self.nDays = nDays
        self.strideLen = strideLen
        self.kernelLen = kernelLen
        self.gaussianSmoothWidth = gaussianSmoothWidth
        self.bidirectional = bidirectional
        self.learning_rate = learning_rate
        self.white_noise_SD = white_noise_SD
        self.constant_offset_SD = constant_offset_SD
        self.weight_decay = weight_decay
        self.smoothing = smoothing
        self.day_transforms = day_transforms
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

        self.inputLayerNonlinearity = nn.Softsign()
        self.unfolder = nn.Unfold((self.kernelLen, 1), dilation=1, padding=0, stride=self.strideLen)
        # Project-local helper; its positional arguments are presumably
        # (channels, kernel_size, sigma) - confirm against `augmentations`.
        self.gaussianSmoother = GaussianSmoothing(neural_dim, 20, self.gaussianSmoothWidth, dim=1)

        # Per-day transformation weights, initialised to the identity map.
        # torch.randn is kept (although overwritten) so the global RNG stream,
        # and therefore seeded initialisation of the layers below, is unchanged.
        self.dayWeights = nn.Parameter(torch.randn(nDays, neural_dim, neural_dim))
        self.dayBias = nn.Parameter(torch.zeros(nDays, 1, neural_dim))
        with torch.no_grad():
            self.dayWeights.copy_(torch.eye(neural_dim).expand(nDays, -1, -1))

        # GRU layer; each input frame is one flattened unfold window.
        self.gru_decoder = nn.GRU(
            (neural_dim) * self.kernelLen,
            hidden_dim,
            layer_dim,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        for name, param in self.gru_decoder.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Fully connected output layer
        self.fc_decoder_out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, n_classes + 1)  # +1 for CTC blank

        # Loss function
        self.ctc_loss = nn.CTCLoss(blank=0, reduction="mean", zero_infinity=True)

    def get_neural_embeddings(self, neuralInput, dayIdx):
        """
        Return the GRU outputs (before the final linear layer).

        neuralInput: (batch, time, features)
        dayIdx: 1-D index tensor with one day index per batch element

        Returns a tensor of shape (batch, n_windows, hidden_dim * n_directions).
        """
        if self.smoothing:
            # The smoother is applied with time as the last axis.
            neuralInput = torch.permute(neuralInput, (0, 2, 1))
            neuralInput = self.gaussianSmoother(neuralInput)
            neuralInput = torch.permute(neuralInput, (0, 2, 1))

        if self.day_transforms:
            # Day-specific affine map: x @ W_day + b_day.
            dayWeights = torch.index_select(self.dayWeights, 0, dayIdx)
            dayBias = torch.index_select(self.dayBias, 0, dayIdx)
            transformedNeural = torch.einsum("btd,bdk->btk", neuralInput, dayWeights) + dayBias
        else:
            transformedNeural = neuralInput
        transformedNeural = self.inputLayerNonlinearity(transformedNeural)

        # Sliding windows: (B, T, D) -> (B, D, T, 1) -> unfold -> (B, windows, D * kernelLen)
        stridedInputs = torch.permute(
            self.unfolder(torch.unsqueeze(torch.permute(transformedNeural, (0, 2, 1)), 3)), (0, 2, 1)
        )

        # With no initial state given, nn.GRU starts from zeros on the input's
        # device and dtype, which is what the explicit h0 used to provide.
        hid, _ = self.gru_decoder(stridedInputs)
        return hid

    def forward(self, neuralInput, dayIdx):
        """
        Forward pass of the model.

        neuralInput: (batch, time, features)
        dayIdx: 1-D index tensor with one day index per batch element

        Returns logits of shape (batch, n_windows, n_classes + 1).
        """
        return self.fc_decoder_out(self.get_neural_embeddings(neuralInput, dayIdx))

    def _ctc_input_lengths(self, X_len):
        """
        Convert raw input lengths (in time steps) to output-frame lengths.

        NOTE(review): this is (len - kernelLen) / strideLen truncated to int,
        i.e. one less than the number of complete unfold windows; kept as in
        the original training code.
        """
        return ((X_len - self.kernelLen) / self.strideLen).to(torch.int32)

    def _ctc_loss_from_logits(self, pred, y, input_lens, y_len):
        """CTC loss for batch-first logits; nn.CTCLoss wants (time, batch, classes)."""
        return self.ctc_loss(
            torch.permute(pred.log_softmax(2), [1, 0, 2]),
            y,
            input_lens,
            y_len,
        )

    def _edit_distance_counts(self, pred, y, input_lens, y_len):
        """
        Greedy-decode each sequence and return (total edit distance, total
        reference length) over the batch.

        Only the frames counted in `input_lens` are decoded, so the metric
        looks at the same frames as the CTC loss and ignores frames that
        come from padding.
        """
        total_edit_distance, total_seq_length = 0, 0
        for i in range(pred.shape[0]):
            n_frames = max(int(input_lens[i]), 0)
            decodedSeq = torch.argmax(pred[i, :n_frames, :], dim=-1)
            # CTC collapse: merge repeats, then drop blanks (index 0).
            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)
            decodedSeq = decodedSeq[decodedSeq != 0].tolist()

            trueSeq = y[i][:y_len[i]].tolist()
            matcher = SequenceMatcher(a=trueSeq, b=decodedSeq)
            total_edit_distance += matcher.distance()
            total_seq_length += len(trueSeq)
        return total_edit_distance, total_seq_length

    def training_step(self, batch, batch_idx):
        """
        Training step - Runs forward pass, computes loss, and returns it for backprop.
        """
        X, y, X_len, y_len, dayIdx, sentence = batch
        X, y, X_len, y_len, dayIdx = X.to(self.device), y.to(self.device), X_len.to(self.device), y_len.to(self.device), dayIdx.to(self.device)

        # Noise augmentation. Done out of place so the tensor handed in by the
        # data loader is never modified.
        if self.white_noise_SD > 0:
            X = X + torch.randn(X.shape, device=X.device, dtype=X.dtype) * self.white_noise_SD
        if self.constant_offset_SD > 0:
            X = X + torch.randn([X.shape[0], 1, X.shape[2]], device=X.device, dtype=X.dtype) * self.constant_offset_SD

        # Forward pass
        pred = self.forward(X, dayIdx)

        # Compute CTC Loss
        loss = self._ctc_loss_from_logits(pred, y, self._ctc_input_lengths(X_len), y_len)

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step - Computes loss and CER.
        """
        X, y, X_len, y_len, dayIdx, sentence = batch
        X, y, X_len, y_len, dayIdx = X.to(self.device), y.to(self.device), X_len.to(self.device), y_len.to(self.device), dayIdx.to(self.device)

        pred = self.forward(X, dayIdx)
        input_lens = self._ctc_input_lengths(X_len)

        # Compute CTC Loss
        loss = self._ctc_loss_from_logits(pred, y, input_lens, y_len)

        # Compute CER (Phoneme Error Rate)
        total_edit_distance, total_seq_length = self._edit_distance_counts(pred, y, input_lens, y_len)
        cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0

        # NOTE(review): val_CER is logged once per batch, so the epoch value is
        # an aggregate of per-batch ratios rather than a single corpus-level
        # ratio - confirm this is the intended metric.
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_CER", cer, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler.

        Uses Adam with ReduceLROnPlateau monitoring `val_loss`.
        WarmupCosineAnnealingLR (driven by warmup_steps / total_steps) was an
        alternative here and is currently not wired in.
        """
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.99),
            eps=1e-8,  # eps was 0.1
        )

        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=3)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}




########### TIME WARPING

class LightningTimeWarpingGRUDecoder(pl.LightningModule):
    """
    PyTorch Lightning GRU decoder with LayerNorm and a temporal convolution
    on top of the GRU outputs, trained with CTC loss.

    Pipeline: optional Gaussian smoothing -> optional per-day affine transform
    -> Softsign -> unfold into windows of `kernelLen` time steps every
    `strideLen` steps -> GRU -> LayerNorm -> Conv1d over time (kernel 5,
    "same" padding) -> linear read-out over n_classes + 1 symbols (index 0 is
    the CTC blank).

    Batches are expected to unpack into six items
    (X, y, X_len, y_len, dayIdx, sentence); `sentence` is not used here.

    NOTE: `forward` used to apply smoothing, the day transform and Softsign
    itself and then call `get_neural_embeddings`, which applied all three a
    second time. The preprocessing now runs once. Checkpoints trained with
    the doubled preprocessing will therefore give different outputs and may
    need fine-tuning or retraining.
    """

    def __init__(
        self,
        neural_dim,
        n_classes,
        hidden_dim,
        layer_dim,
        nDays=24,
        dropout=0.1,
        strideLen=4,
        kernelLen=14,
        gaussianSmoothWidth=0,
        bidirectional=False,
        learning_rate=1e-3,
        white_noise_SD=0.01,
        constant_offset_SD=0.01,
        weight_decay=1e-5,
        smoothing=True,
        day_transforms=True,
        warmup_steps=1000,
        total_steps=10000,
        time_warp_factor=0.2,
    ):
        """
        Args:
            neural_dim: Number of input features per time step.
            n_classes: Number of output classes, excluding the CTC blank.
            hidden_dim: GRU hidden size.
            layer_dim: Number of stacked GRU layers.
            nDays: Number of distinct day indices with their own transform.
            dropout: Dropout between GRU layers.
            strideLen: Step between consecutive unfold windows.
            kernelLen: Number of time steps per unfold window.
            gaussianSmoothWidth: Passed to the project's GaussianSmoothing.
            bidirectional: Whether the GRU is bidirectional.
            learning_rate: Adam learning rate.
            white_noise_SD: Std of per-element Gaussian noise added in training.
            constant_offset_SD: Std of per-(sample, feature) Gaussian offset
                added in training (constant over time).
            weight_decay: Adam weight decay.
            smoothing: Apply Gaussian smoothing to the input.
            day_transforms: Apply the per-day affine transform.
            warmup_steps: Stored only; not used by the active scheduler.
            total_steps: Stored only; not used by the active scheduler.
            time_warp_factor: Stored only; not read anywhere in this class.
        """
        super().__init__()

        self.layer_dim = layer_dim
        self.hidden_dim = hidden_dim
        self.neural_dim = neural_dim
        self.n_classes = n_classes
        self.nDays = nDays
        self.strideLen = strideLen
        self.kernelLen = kernelLen
        self.gaussianSmoothWidth = gaussianSmoothWidth
        self.bidirectional = bidirectional
        self.learning_rate = learning_rate
        self.white_noise_SD = white_noise_SD
        self.constant_offset_SD = constant_offset_SD
        self.weight_decay = weight_decay
        self.smoothing = smoothing
        self.day_transforms = day_transforms
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.time_warp_factor = time_warp_factor

        self.inputLayerNonlinearity = nn.Softsign()
        self.unfolder = nn.Unfold((self.kernelLen, 1), dilation=1, padding=0, stride=self.strideLen)
        # Project-local helper; its positional arguments are presumably
        # (channels, kernel_size, sigma) - confirm against `augmentations`.
        self.gaussianSmoother = GaussianSmoothing(neural_dim, 20, self.gaussianSmoothWidth, dim=1)

        # Per-day transformation weights, initialised to the identity map.
        # torch.randn is kept (although overwritten) so the global RNG stream,
        # and therefore seeded initialisation of the layers below, is unchanged.
        self.dayWeights = nn.Parameter(torch.randn(nDays, neural_dim, neural_dim))
        self.dayBias = nn.Parameter(torch.zeros(nDays, 1, neural_dim))
        with torch.no_grad():
            self.dayWeights.copy_(torch.eye(neural_dim).expand(nDays, -1, -1))

        # GRU layer; each input frame is one flattened unfold window.
        self.gru_decoder = nn.GRU(
            (neural_dim) * self.kernelLen,
            hidden_dim,
            layer_dim,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        for name, param in self.gru_decoder.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Width of the GRU output features (doubled when bidirectional).
        rnn_out_dim = hidden_dim * 2 if bidirectional else hidden_dim

        self.layer_norm = nn.LayerNorm(rnn_out_dim)
        # Temporal convolution over the frame axis; stride 1 with "same"
        # padding keeps the number of frames unchanged.
        self.time_pooling = nn.Conv1d(rnn_out_dim, rnn_out_dim, kernel_size=5, stride=1, padding="same")

        # Fully connected output layer
        self.fc_decoder_out = nn.Linear(rnn_out_dim, n_classes + 1)  # +1 for CTC blank

        # Loss function
        self.ctc_loss = nn.CTCLoss(blank=0, reduction="mean", zero_infinity=True)

    def get_neural_embeddings(self, neuralInput, dayIdx):
        """
        Return the normalised, temporally convolved GRU features (before the
        final linear layer).

        neuralInput: raw input of shape (batch, time, features)
        dayIdx: 1-D index tensor with one day index per batch element

        Returns a tensor of shape (batch, n_windows, hidden_dim * n_directions).
        """
        if self.smoothing:
            # The smoother is applied with time as the last axis.
            neuralInput = torch.permute(neuralInput, (0, 2, 1))
            neuralInput = self.gaussianSmoother(neuralInput)
            neuralInput = torch.permute(neuralInput, (0, 2, 1))

        if self.day_transforms:
            # Day-specific affine map: x @ W_day + b_day.
            dayWeights = torch.index_select(self.dayWeights, 0, dayIdx)
            dayBias = torch.index_select(self.dayBias, 0, dayIdx)
            transformedNeural = torch.einsum("btd,bdk->btk", neuralInput, dayWeights) + dayBias
        else:
            transformedNeural = neuralInput
        transformedNeural = self.inputLayerNonlinearity(transformedNeural)

        # Sliding windows: (B, T, D) -> (B, D, T, 1) -> unfold -> (B, windows, D * kernelLen)
        stridedInputs = torch.permute(
            self.unfolder(torch.unsqueeze(torch.permute(transformedNeural, (0, 2, 1)), 3)), (0, 2, 1)
        )

        # With no initial state given, nn.GRU starts from zeros on the input's
        # device and dtype, which is what the explicit h0 used to provide.
        hid, _ = self.gru_decoder(stridedInputs)

        # norm
        hid = self.layer_norm(hid)

        # time pooling: Conv1d wants (batch, channels, frames)
        hid = torch.permute(hid, (0, 2, 1))
        hid = self.time_pooling(hid)
        hid = torch.permute(hid, (0, 2, 1))

        return hid

    def forward(self, neuralInput, dayIdx):
        """
        Forward pass of the model.

        neuralInput: raw input of shape (batch, time, features)
        dayIdx: 1-D index tensor with one day index per batch element

        Returns logits of shape (batch, n_windows, n_classes + 1).
        """
        # get_neural_embeddings already does smoothing, the day transform and
        # the nonlinearity, so the raw input is passed straight through and
        # that preprocessing is applied exactly once.
        hid = self.get_neural_embeddings(neuralInput, dayIdx)
        # Final output layer
        return self.fc_decoder_out(hid)

    def _ctc_input_lengths(self, X_len):
        """
        Convert raw input lengths (in time steps) to output-frame lengths.

        NOTE(review): this is (len - kernelLen) / strideLen truncated to int,
        i.e. one less than the number of complete unfold windows; kept as in
        the original training code.
        """
        return ((X_len - self.kernelLen) / self.strideLen).to(torch.int32)

    def _ctc_loss_from_logits(self, pred, y, input_lens, y_len):
        """CTC loss for batch-first logits; nn.CTCLoss wants (time, batch, classes)."""
        return self.ctc_loss(
            torch.permute(pred.log_softmax(2), [1, 0, 2]),
            y,
            input_lens,
            y_len,
        )

    def _edit_distance_counts(self, pred, y, input_lens, y_len):
        """
        Greedy-decode each sequence and return (total edit distance, total
        reference length) over the batch.

        Only the frames counted in `input_lens` are decoded, so the metric
        looks at the same frames as the CTC loss and ignores frames that
        come from padding.
        """
        total_edit_distance, total_seq_length = 0, 0
        for i in range(pred.shape[0]):
            n_frames = max(int(input_lens[i]), 0)
            decodedSeq = torch.argmax(pred[i, :n_frames, :], dim=-1)
            # CTC collapse: merge repeats, then drop blanks (index 0).
            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)
            decodedSeq = decodedSeq[decodedSeq != 0].tolist()

            trueSeq = y[i][:y_len[i]].tolist()
            matcher = SequenceMatcher(a=trueSeq, b=decodedSeq)
            total_edit_distance += matcher.distance()
            total_seq_length += len(trueSeq)
        return total_edit_distance, total_seq_length

    def training_step(self, batch, batch_idx):
        """
        Training step - Runs forward pass, computes loss, and returns it for backprop.
        """
        X, y, X_len, y_len, dayIdx, sentence = batch
        X, y, X_len, y_len, dayIdx = X.to(self.device), y.to(self.device), X_len.to(self.device), y_len.to(self.device), dayIdx.to(self.device)

        # Noise augmentation. Done out of place so the tensor handed in by the
        # data loader is never modified.
        if self.white_noise_SD > 0:
            X = X + torch.randn(X.shape, device=X.device, dtype=X.dtype) * self.white_noise_SD
        if self.constant_offset_SD > 0:
            X = X + torch.randn([X.shape[0], 1, X.shape[2]], device=X.device, dtype=X.dtype) * self.constant_offset_SD

        # Forward pass
        pred = self.forward(X, dayIdx)

        # Compute CTC Loss
        loss = self._ctc_loss_from_logits(pred, y, self._ctc_input_lengths(X_len), y_len)

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step - Computes loss and CER.
        """
        X, y, X_len, y_len, dayIdx, sentence = batch
        X, y, X_len, y_len, dayIdx = X.to(self.device), y.to(self.device), X_len.to(self.device), y_len.to(self.device), dayIdx.to(self.device)

        pred = self.forward(X, dayIdx)
        input_lens = self._ctc_input_lengths(X_len)

        # Compute CTC Loss
        loss = self._ctc_loss_from_logits(pred, y, input_lens, y_len)

        # Compute CER (Phoneme Error Rate)
        total_edit_distance, total_seq_length = self._edit_distance_counts(pred, y, input_lens, y_len)
        cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0

        # NOTE(review): val_CER is logged once per batch, so the epoch value is
        # an aggregate of per-batch ratios rather than a single corpus-level
        # ratio - confirm this is the intended metric.
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_CER", cer, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler.

        Uses Adam with ReduceLROnPlateau monitoring `val_loss`.
        WarmupCosineAnnealingLR (driven by warmup_steps / total_steps) was an
        alternative here and is currently not wired in.
        """
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.99),
            eps=1e-8,  # eps was 0.1
        )

        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=3)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}











######################





class ContrastiveGRUSentenceDecoder(pl.LightningModule):
    """
    GRU encoder trained contrastively against sentence embeddings.

    The neural input is optionally smoothed and passed through a per-day
    affine transform, cut into overlapping windows, encoded by a GRU, projected
    to ``latent_dim`` and mean-pooled over windows. The pooled vector is paired
    with the embedding that ``text_encoder`` produces for the matching
    sentence, and ``ContrastiveLoss`` is applied to the pair.

    Parameters
    ----------
    neural_dim : number of input features per time step.
    latent_dim : size of the pooled neural embedding (presumably has to match
        the text embedding size - confirm against ``ContrastiveLoss``).
    hidden_dim, layer_dim : GRU hidden size and number of stacked layers.
    nDays : number of per-day transforms that are allocated.
    dropout : dropout passed to ``nn.GRU``.
    strideLen, kernelLen : stride and length of the sliding window (``nn.Unfold``).
    gaussianSmoothWidth : forwarded to ``GaussianSmoothing``.
    bidirectional : whether the GRU is bidirectional.
    learning_rate, weight_decay : Adam settings used by ``configure_optimizers``.
    white_noise_SD, constant_offset_SD : standard deviations of the noise added
        to the input during training; a value <= 0 disables that augmentation.
    smoothing, day_transforms : toggles for the two preprocessing stages.
    text_encoder : object exposing ``get_text_features(**inputs)``.
    text_tokenizer : callable invoked as
        ``text_tokenizer(text=..., return_tensors="pt", padding=True)``.
    temperature : forwarded to ``ContrastiveLoss``.
    """

    def __init__(
        self,
        neural_dim,
        latent_dim,
        hidden_dim,
        layer_dim,
        nDays=24,
        dropout=0.1,
        strideLen=4,
        kernelLen=14,
        gaussianSmoothWidth=0,
        bidirectional=False,
        learning_rate=1e-3,
        white_noise_SD=0.01,
        constant_offset_SD=0.01,
        weight_decay=1e-5,
        smoothing =True,
        day_transforms = True,
        text_encoder = None,
        text_tokenizer = None,
        temperature = 0.1,

    ):
        super().__init__()

        self.layer_dim = layer_dim
        self.hidden_dim = hidden_dim
        self.neural_dim = neural_dim
        self.latent_dim = latent_dim
        self.nDays = nDays
        self.strideLen = strideLen
        self.kernelLen = kernelLen
        self.gaussianSmoothWidth = gaussianSmoothWidth
        self.bidirectional = bidirectional
        self.learning_rate = learning_rate
        self.white_noise_SD = white_noise_SD
        self.constant_offset_SD = constant_offset_SD
        self.weight_decay = weight_decay
        self.smoothing = smoothing
        self.day_transforms = day_transforms

        # NOTE: the text encoder is not frozen here. If it is an nn.Module it is
        # registered as a submodule, so configure_optimizers() hands its
        # parameters to Adam together with the GRU's.
        self.text_encoder = text_encoder
        self.text_tokenizer = text_tokenizer

        self.temperature = temperature

        self.inputLayerNonlinearity = nn.Softsign()
        self.unfolder = nn.Unfold((self.kernelLen, 1), dilation=1, padding=0, stride=self.strideLen)
        self.gaussianSmoother = GaussianSmoothing(neural_dim, 20, self.gaussianSmoothWidth, dim=1)

        # Per-day affine transform, initialised to the identity. The randn draw
        # is kept so the global RNG stream (and therefore the seeded GRU
        # initialisation below) is the same as before.
        self.dayWeights = nn.Parameter(torch.randn(nDays, neural_dim, neural_dim))
        self.dayBias = nn.Parameter(torch.zeros(nDays, 1, neural_dim))
        with torch.no_grad():
            self.dayWeights.copy_(torch.eye(neural_dim).expand(nDays, -1, -1))

        # GRU over the flattened windows (kernelLen time steps x neural_dim features).
        self.gru_decoder = nn.GRU(
            (neural_dim) * self.kernelLen,
            hidden_dim,
            layer_dim,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        for name, param in self.gru_decoder.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Projection from the GRU output to the embedding space.
        self.fc_decoder_out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, latent_dim)

        # Loss functions. ctc_loss is not used by the steps of this class; it is
        # kept so the attribute remains available.
        self.ctc_loss = nn.CTCLoss(blank=0, reduction="mean", zero_infinity=True)
        self.contrastive_loss = ContrastiveLoss(temperature=temperature)

    def forward(self, neuralInput, dayIdx, sentence):
        """
        Encode a batch of neural recordings and the matching sentences.

        neuralInput: (batch, time, features)
        dayIdx: Session index of every batch element (selects the day transform)
        sentence: text passed to ``text_tokenizer`` (one entry per batch element
            is assumed - confirm against the data loader)

        Returns
        -------
        neural_embeddings : (batch, latent_dim), mean over all windows. Windows
            that cover padding are included in the mean.
        sentence_embedding : output of ``text_encoder.get_text_features``.
        """
        if self.smoothing:
            # The smoother is applied with time as the last axis.
            neuralInput = torch.permute(neuralInput, (0, 2, 1))
            neuralInput = self.gaussianSmoother(neuralInput)
            neuralInput = torch.permute(neuralInput, (0, 2, 1))

        if self.day_transforms:
            # Apply day-specific transformations
            dayWeights = torch.index_select(self.dayWeights, 0, dayIdx)
            transformedNeural = torch.einsum("btd,bdk->btk", neuralInput, dayWeights) + torch.index_select(self.dayBias, 0, dayIdx)
            transformedNeural = self.inputLayerNonlinearity(transformedNeural)
        else:
            transformedNeural = self.inputLayerNonlinearity(neuralInput)

        # Apply unfolding (sliding window): (batch, windows, features * kernelLen)
        stridedInputs = torch.permute(
            self.unfolder(torch.unsqueeze(torch.permute(transformedNeural, (0, 2, 1)), 3)), (0, 2, 1)
        )

        # nn.GRU starts from an all-zero hidden state when none is given, which
        # is what the explicit (detached) zero tensor used to provide.
        hid, _ = self.gru_decoder(stridedInputs)

        tokens = self.text_tokenizer(text=sentence, return_tensors="pt", padding=True)
        tokens = {k: v.to(self.device) for k, v in tokens.items()}
        sentence_embedding = self.text_encoder.get_text_features(**tokens)

        # Pool the per-window projections by averaging over the window axis.
        neural_embeddings = torch.mean(self.fc_decoder_out(hid), dim=1)

        return neural_embeddings, sentence_embedding

    def training_step(self, batch, batch_idx):
        """
        Training step - augments the input, runs the forward pass and returns
        the contrastive loss for backprop.

        ``batch`` is unpacked as (X, y, X_len, y_len, dayIdx, sentence); only X,
        dayIdx and sentence are used here.
        """
        X, _, _, _, dayIdx, sentence = batch
        X, dayIdx = X.to(self.device), dayIdx.to(self.device)

        # Noise augmentation. Built out-of-place so the tensor owned by the
        # caller's batch is never modified (``.to`` returns the same tensor when
        # it already lives on the target device).
        if self.white_noise_SD > 0:
            X = X + torch.randn(X.shape, device=self.device) * self.white_noise_SD
        if self.constant_offset_SD > 0:
            X = X + torch.randn([X.shape[0], 1, X.shape[2]], device=self.device) * self.constant_offset_SD

        # Forward pass
        neural_embeddings, sentence_embedding = self.forward(X, dayIdx, sentence)

        # compute contrastive loss
        loss = self.contrastive_loss(neural_embeddings, sentence_embedding, normalize =True)

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step - computes and logs the contrastive loss (no augmentation).
        """
        X, _, _, _, dayIdx, sentence = batch
        X, dayIdx = X.to(self.device), dayIdx.to(self.device)

        # Forward pass
        neural_embeddings, sentence_embedding = self.forward(X, dayIdx, sentence)

        # compute contrastive loss
        loss = self.contrastive_loss(neural_embeddings, sentence_embedding, normalize =True)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler.

        Adam over all parameters of this module, with a ReduceLROnPlateau
        scheduler that follows the logged "val_loss".
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, betas=(0.9, 0.999),eps = 1e-6)

        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}










################ MULTIBRANCH  #################

##LIGHTNING 

class GRUBranch(nn.Module):
    """
    Single branch of the multi-branch model.

    A branch picks a random subset of the input channels, cuts the signal into
    overlapping windows, runs a GRU over the windows, optionally applies
    self-attention over the GRU outputs and projects every window to
    ``n_classes + 1`` logits (the extra class is the CTC blank).

    Parameters
    ----------
    neural_dim : number of channels of the incoming tensor.
    n_classes : number of output classes (without the CTC blank).
    hidden_dim, layer_dim : GRU hidden size and number of stacked layers.
    nDays : stored only; not used by the branch itself.
    dropout : dropout passed to ``nn.GRU``.
    strideLen, kernelLen : stride and length of the sliding window.
    bidirectional : whether the GRU is bidirectional.
    use_attention : add a 4-head ``nn.MultiheadAttention`` after the GRU.
    fraction : share of the ``neural_dim`` channels this branch looks at.
    """
    def __init__(
        self,
        neural_dim,
        n_classes,
        hidden_dim,
        layer_dim,
        nDays=24,
        dropout=0.1,
        strideLen=4,
        kernelLen=14,
        bidirectional=False,
        use_attention=True,
        fraction =1.,
    ):
        super(GRUBranch, self).__init__()
        self.layer_dim = layer_dim
        self.hidden_dim = hidden_dim
        self.neural_dim = neural_dim
        self.n_classes = n_classes
        self.nDays = nDays
        self.strideLen = strideLen
        self.kernelLen = kernelLen
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.use_attention = use_attention
        self.fraction = fraction

        # The channel selection is random (and a permutation even for
        # fraction == 1), and the GRU input weights are tied to this exact
        # order. Registering it as a buffer stores it in state_dict(), so a
        # reloaded checkpoint sees the channels in the order it was trained on.
        num_indices = int(self.fraction * self.neural_dim)
        self.register_buffer("branch_indices", torch.randperm(self.neural_dim)[:num_indices])

        self.unfolder = nn.Unfold((self.kernelLen, 1), dilation=1, padding=0, stride=self.strideLen)

        # GRU layer over the flattened windows of the selected channels
        self.gru_decoder = nn.GRU(
            self.branch_indices.numel() * self.kernelLen,
            hidden_dim,
            layer_dim,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        # Fully connected output layer
        out_features = hidden_dim * 2 if bidirectional else hidden_dim
        self.fc_decoder_out = nn.Linear(out_features, n_classes + 1)  # +1 for CTC blank

        if self.use_attention:
            self.attention = nn.MultiheadAttention(out_features, 4, batch_first=True)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Load this module's state, tolerating checkpoints without ``branch_indices``."""
        key = prefix + "branch_indices"
        if key not in state_dict:
            # Checkpoint written before the indices were persisted: keep the
            # indices of this instance instead of reporting a missing key.
            state_dict[key] = self.branch_indices
        super(GRUBranch, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs,
        )

    def forward(self, transformedNeural):
        """
        transformedNeural: (batch, time, neural_dim)

        Returns logits of shape (batch, windows, n_classes + 1).
        """
        # Keep only the channels assigned to this branch
        transformedNeural = transformedNeural[:, :, self.branch_indices].contiguous()

        # Apply unfolding (sliding window): (batch, windows, channels * kernelLen)
        stridedInputs = torch.permute(
            self.unfolder(torch.unsqueeze(torch.permute(transformedNeural, (0, 2, 1)), 3)), (0, 2, 1)
        ).contiguous()

        # nn.GRU starts from an all-zero hidden state when none is given, which
        # is what the explicit (detached) zero tensor used to provide.
        hid, _ = self.gru_decoder(stridedInputs)

        if self.use_attention:
            # Self-attention across the windows of each sequence
            hid, _ = self.attention(hid, hid, hid)

        # Final output layer
        return self.fc_decoder_out(hid)

        



class LightningMultiBranchGRUDecoder(pl.LightningModule):
    """
    Ensemble of ``GRUBranch`` decoders trained with CTC.

    The input is optionally smoothed and passed through a per-day affine
    transform; every branch then produces per-window logits and the logits are
    averaged across branches.

    Parameters
    ----------
    neural_dim : number of input features per time step.
    n_classes : number of output classes (without the CTC blank).
    hidden_dim, layer_dim : GRU hidden size and number of layers of every branch.
    nDays : number of per-day transforms that are allocated.
    dropout : dropout passed to every branch.
    strideLens, kernelLens : per-branch window stride / length; one branch is
        built per entry of ``strideLens``.
    gaussianSmoothWidth : forwarded to ``GaussianSmoothing``.
    bidirectional : whether the GRUs are bidirectional.
    learning_rate, weight_decay : Adam settings used by ``configure_optimizers``.
    white_noise_SD, constant_offset_SD : standard deviations of the noise added
        to the input during training; a value <= 0 disables that augmentation.
    smoothing, day_transforms : toggles for the two preprocessing stages.
    warmup_steps, total_steps : stored, but not used by the current scheduler.
    use_attention : forwarded to every branch.
    fraction : share of channels seen by every branch except the first, which
        always gets ``fraction=1``.
    """

    def __init__(
        self,
        neural_dim,
        n_classes,
        hidden_dim,
        layer_dim,
        nDays=24,
        dropout=0.1,
        strideLens=(4,4,4),
        kernelLens=(32,32,32),
        gaussianSmoothWidth=0,
        bidirectional=False,
        learning_rate=1e-3,
        white_noise_SD=0.01,
        constant_offset_SD=0.01,
        weight_decay=1e-5,
        smoothing =True,
        day_transforms = True,
        warmup_steps = 1000,
        total_steps = 10000,
        use_attention = True,
        fraction = 0.8,
    ):
        super().__init__()

        self.layer_dim = layer_dim
        self.hidden_dim = hidden_dim
        self.neural_dim = neural_dim
        self.n_classes = n_classes
        self.nDays = nDays
        self.strideLens = strideLens
        self.kernelLens = kernelLens
        self.gaussianSmoothWidth = gaussianSmoothWidth
        self.bidirectional = bidirectional
        self.learning_rate = learning_rate
        self.white_noise_SD = white_noise_SD
        self.constant_offset_SD = constant_offset_SD
        self.weight_decay = weight_decay
        self.smoothing = smoothing
        self.day_transforms = day_transforms
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.fraction = fraction

        # The first branch's window settings define the output lengths used by
        # the CTC loss and by the CER decoding.
        self.strideLen = self.strideLens[0]
        self.kernelLen = self.kernelLens[0]

        self.inputLayerNonlinearity = nn.Softsign()
        self.gaussianSmoother = GaussianSmoothing(neural_dim, 20, self.gaussianSmoothWidth, dim=1)

        # Per-day affine transform, initialised to the identity. The randn draw
        # is kept so the global RNG stream (and therefore the seeded branch
        # initialisation below) is the same as before.
        self.dayWeights = nn.Parameter(torch.randn(nDays, neural_dim, neural_dim))
        self.dayBias = nn.Parameter(torch.zeros(nDays, 1, neural_dim))
        with torch.no_grad():
            self.dayWeights.copy_(torch.eye(neural_dim).expand(nDays, -1, -1))

        # One branch per stride entry. Branch 0 sees every channel, the others a
        # random subset of size ``fraction``.
        branches_list = [
            GRUBranch(
                neural_dim,
                n_classes,
                hidden_dim,
                layer_dim,
                nDays,
                dropout,
                strideLen=self.strideLens[i],
                kernelLen=self.kernelLens[i],
                bidirectional=self.bidirectional,
                use_attention=use_attention,
                fraction=1. if i == 0 else self.fraction,
            )
            for i in range(len(self.strideLens))
        ]
        self.branches = nn.ModuleList(branches_list)

        # Loss function
        self.ctc_loss = nn.CTCLoss(blank=0, reduction="mean", zero_infinity=True)

    def forward(self, neuralInput, dayIdx):
        """
        Forward pass of the model.
        neuralInput: (batch, time, features)
        dayIdx: Session index

        Returns the branch-averaged logits, (batch, windows, n_classes + 1).
        """
        if self.smoothing:
            # The smoother is applied with time as the last axis.
            neuralInput = torch.permute(neuralInput, (0, 2, 1))
            neuralInput = self.gaussianSmoother(neuralInput)
            neuralInput = torch.permute(neuralInput, (0, 2, 1))

        if self.day_transforms:
            # Apply day-specific transformations
            dayWeights = torch.index_select(self.dayWeights, 0, dayIdx)
            transformedNeural = torch.einsum("btd,bdk->btk", neuralInput, dayWeights) + torch.index_select(self.dayBias, 0, dayIdx)
            transformedNeural = self.inputLayerNonlinearity(transformedNeural)
        else:
            transformedNeural = self.inputLayerNonlinearity(neuralInput)

        # torch.stack needs every branch to return the same number of windows,
        # i.e. the (stride, kernel) pairs must give equal output lengths.
        branch_logits = [branch(transformedNeural) for branch in self.branches]
        return torch.stack(branch_logits).mean(0)

    def training_step(self, batch, batch_idx):
        """
        Training step - Runs forward pass, computes loss, and returns it for backprop.

        ``batch`` is unpacked as (X, y, X_len, y_len, dayIdx, sentence).
        """
        X, y, X_len, y_len, dayIdx, _ = batch
        X, y, X_len, y_len, dayIdx = X.to(self.device), y.to(self.device), X_len.to(self.device), y_len.to(self.device), dayIdx.to(self.device)

        # Noise augmentation. Built out-of-place so the tensor owned by the
        # caller's batch is never modified (``.to`` returns the same tensor when
        # it already lives on the target device).
        if self.white_noise_SD > 0:
            X = X + torch.randn(X.shape, device=self.device) * self.white_noise_SD
        if self.constant_offset_SD > 0:
            X = X + torch.randn([X.shape[0], 1, X.shape[2]], device=self.device) * self.constant_offset_SD

        # Forward pass
        pred = self.forward(X, dayIdx)

        # Compute CTC Loss (expects time-major log-probabilities)
        loss = self.ctc_loss(
            torch.permute(pred.log_softmax(2), [1, 0, 2]),
            y,
            ((X_len - self.kernelLen) / self.strideLen).to(torch.int32),
            y_len,
        )

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step - Computes loss and CER.

        The CER is the summed edit distance between the greedy CTC decoding and
        the target sequences, divided by the summed target length.
        """
        X, y, X_len, y_len, dayIdx, _ = batch
        X, y, X_len, y_len, dayIdx = X.to(self.device), y.to(self.device), X_len.to(self.device), y_len.to(self.device), dayIdx.to(self.device)

        pred = self.forward(X, dayIdx)

        # Number of output frames attributed to every sequence; shared by the
        # loss and the decoding below.
        input_lens = ((X_len - self.kernelLen) / self.strideLen).to(torch.int32)

        # Compute CTC Loss
        loss = self.ctc_loss(
            torch.permute(pred.log_softmax(2), [1, 0, 2]),
            y,
            input_lens,
            y_len,
        )

        # Compute CER (Phoneme Error Rate)
        total_edit_distance, total_seq_length = 0, 0
        for i in range(pred.shape[0]):
            # Decode exactly the frames the CTC loss was computed on; frames
            # past that point come from padding. max() guards against a
            # negative slice end for sequences shorter than one window.
            n_frames = max(int(input_lens[i]), 0)
            decodedSeq = torch.argmax(pred[i, :n_frames, :], dim=-1)
            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)  # collapse repeats
            decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()  # drop CTC blanks

            trueSeq = y[i][:y_len[i]].cpu().numpy()
            matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())
            total_edit_distance += matcher.distance()
            total_seq_length += len(trueSeq)

        cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_CER", cer, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler.

        Adam over all parameters of this module, with a ReduceLROnPlateau
        scheduler that follows the logged "val_loss".
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, betas=(0.9, 0.99),
                                      eps=0.1,)

        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}



class LightningGRUDecoder_MFCC_v3(pl.LightningModule):
    """
    GRU decoder with two heads: CTC phoneme logits and MFCC regression.

    The selected input channels are smoothed, passed through a per-day affine
    transform, cut into overlapping windows and encoded by a GRU. One linear
    head predicts ``n_classes + 1`` logits per window (the extra class is the
    CTC blank); a second head predicts ``mfcc_dim * strideLen`` values per
    window, compared with an L1 loss against the MFCC targets grouped into
    blocks of ``strideLen`` frames.

    Parameters
    ----------
    neural_dim : number of input features; replaced by ``len(channels)`` once
        the channel selection is known.
    n_classes : number of output classes (without the CTC blank).
    hidden_dim, layer_dim : GRU hidden size and number of stacked layers.
    nDays : number of per-day transforms that are allocated.
    dropout : dropout passed to ``nn.GRU``.
    strideLen, kernelLen : stride and length of the sliding window.
    gaussianSmoothWidth : forwarded to ``GaussianSmoothing``.
    bidirectional : whether the GRU is bidirectional.
    learning_rate, weight_decay : Adam settings used by ``configure_optimizers``.
    white_noise_SD, constant_offset_SD : standard deviations of the noise added
        to the input during training; a value <= 0 disables that augmentation.
    mfcc_dim : number of MFCC coefficients per frame.
    mfcc_loss_weight : weight of the L1 term in the total loss.
    channels : indices of the input channels to use; all channels when None.
    clip_mfcc_afer_go : drop the MFCC frames before ``batch["go_onset"]``.
    """

    def __init__(
        self,
        neural_dim,
        n_classes,
        hidden_dim,
        layer_dim,
        nDays=24,
        dropout=0.1,
        strideLen=4,
        kernelLen=14,
        gaussianSmoothWidth=0,
        bidirectional=False,
        learning_rate=1e-3,
        white_noise_SD=0.01,
        constant_offset_SD=0.01,
        weight_decay=1e-5,
        mfcc_dim = 14,
        mfcc_loss_weight = 1.,
        channels = None,
        clip_mfcc_afer_go = False,

    ):
        super().__init__()

        self.layer_dim = layer_dim
        self.hidden_dim = hidden_dim
        self.neural_dim = neural_dim
        self.n_classes = n_classes
        self.nDays = nDays
        self.strideLen = strideLen
        self.kernelLen = kernelLen
        self.gaussianSmoothWidth = gaussianSmoothWidth
        self.bidirectional = bidirectional
        self.learning_rate = learning_rate
        self.white_noise_SD = white_noise_SD
        self.constant_offset_SD = constant_offset_SD
        self.weight_decay = weight_decay
        self.mfcc_loss_weight = mfcc_loss_weight
        self.clip_mfcc_afer_go = clip_mfcc_afer_go
        self.channels = channels
        self.mfcc_dim = mfcc_dim
        if channels is None:
            self.channels = np.arange(0, neural_dim)

        # From here on neural_dim means "number of selected channels".
        print("Resetting neural_dim based on channels")
        self.neural_dim = len(self.channels)
        neural_dim = self.neural_dim
        print("neural_dim", neural_dim, self.neural_dim)

        self.inputLayerNonlinearity = nn.Softsign()
        self.unfolder = nn.Unfold((self.kernelLen, 1), dilation=1, padding=0, stride=self.strideLen)
        # Non-overlapping blocks of strideLen MFCC frames, one block per GRU step.
        self.mfcc_unfolder = nn.Unfold((self.strideLen, 1), dilation=1, padding=0, stride=self.strideLen)

        self.gaussianSmoother = GaussianSmoothing(neural_dim, 20, self.gaussianSmoothWidth, dim=1)

        # Per-day affine transform, initialised to the identity. The randn draw
        # is kept so the global RNG stream (and therefore the seeded GRU
        # initialisation below) is the same as before.
        self.dayWeights = nn.Parameter(torch.randn(nDays, neural_dim, neural_dim))
        self.dayBias = nn.Parameter(torch.zeros(nDays, 1, neural_dim))
        with torch.no_grad():
            self.dayWeights.copy_(torch.eye(neural_dim).expand(nDays, -1, -1))

        # GRU layer
        self.gru_decoder = nn.GRU(
            (neural_dim) * self.kernelLen,
            hidden_dim,
            layer_dim,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        for name, param in self.gru_decoder.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Output heads
        self.fc_decoder_out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, n_classes + 1)  # +1 for CTC blank
        self.mfcc_decoder = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, mfcc_dim*self.strideLen)
        # Loss functions
        self.ctc_loss = nn.CTCLoss(blank=0, reduction="mean", zero_infinity=True)
        self.l1oss = nn.L1Loss()

    def get_neural_embedding(self, neuralInput, dayIdx):
        """
        Encode the neural input up to (and including) the GRU.

        neuralInput: (batch, time, features)
        dayIdx: Session index

        Returns the GRU outputs, one vector per window: (batch, windows,
        hidden_dim), or twice that width for a bidirectional GRU.
        """
        # channel selection
        neuralInput = neuralInput[:, :, self.channels].contiguous()
        # The smoother is applied with time as the last axis.
        neuralInput = torch.permute(neuralInput, (0, 2, 1))
        neuralInput = self.gaussianSmoother(neuralInput)
        neuralInput = torch.permute(neuralInput, (0, 2, 1))

        # Apply day-specific transformations
        dayWeights = torch.index_select(self.dayWeights, 0, dayIdx)
        transformedNeural = torch.einsum("btd,bdk->btk", neuralInput, dayWeights) + torch.index_select(self.dayBias, 0, dayIdx)
        transformedNeural = self.inputLayerNonlinearity(transformedNeural)

        # Apply unfolding (sliding window)
        stridedInputs = torch.permute(
            self.unfolder(torch.unsqueeze(torch.permute(transformedNeural, (0, 2, 1)), 3)), (0, 2, 1)
        )

        # nn.GRU starts from an all-zero hidden state when none is given, which
        # is what the explicit (detached) zero tensor used to provide.
        hid, _ = self.gru_decoder(stridedInputs)
        return hid

    def forward(self, neuralInput, dayIdx):
        """
        Run the encoder and both heads.

        Returns (phoneme_logits, mfcc_pred) with last dimensions
        ``n_classes + 1`` and ``mfcc_dim * strideLen`` respectively.
        """
        hid = self.get_neural_embedding(neuralInput, dayIdx)
        # Final output layers
        phoneme_logits = self.fc_decoder_out(hid)
        mfcc_pred = self.mfcc_decoder(hid)
        return phoneme_logits, mfcc_pred

    def _unfolded_mfcc_targets(self, batch):
        """Pad the per-trial MFCCs of ``batch`` and group them into strideLen-frame blocks."""
        mfcc_list = []
        for idx, mfcc in enumerate(batch["mfcc"]):
            if self.clip_mfcc_afer_go:
                # Keep only the frames from the go cue onwards.
                mfcc = mfcc[batch["go_onset"][idx]:]
            mfcc_list.append(torch.tensor(mfcc))

        # assumes every entry is (time, mfcc_dim) - TODO confirm against the dataset
        MFCC = pad_sequence(mfcc_list, batch_first=True).to(self.device)
        return torch.permute(self.mfcc_unfolder(torch.unsqueeze(torch.permute(MFCC, (0, 2, 1)), 3)), (0, 2, 1))

    def _joint_loss(self, pred, mfcc_pred, mfcc_target, y, input_lens, y_len):
        """Return ``ctc + mfcc_loss_weight * l1`` for one batch."""
        ctc_loss = self.ctc_loss(
            torch.permute(pred.log_softmax(2), [1, 0, 2]),
            y,
            input_lens,
            y_len,
        )

        # Prediction and target may differ in length; compare the common prefix.
        min_seq_len = min(mfcc_target.shape[1], mfcc_pred.shape[1])
        l1_loss = self.l1oss(
            mfcc_pred[:, :min_seq_len, :],
            mfcc_target[:, :min_seq_len, :],
        )
        return ctc_loss + self.mfcc_loss_weight * l1_loss

    def training_step(self, batch, batch_idx):
        """
        Training step - Runs forward pass, computes loss, and returns it for backprop.

        ``batch`` is a dict; the keys read here are "neural_feats", "phone_seq",
        "neural_time_bins", "phone_seq_len", "day", "mfcc" and, when
        ``clip_mfcc_afer_go`` is set, "go_onset".
        """
        X = batch["neural_feats"]
        y = batch["phone_seq"]
        X_len = batch["neural_time_bins"]
        y_len = batch["phone_seq_len"]
        dayIdx = batch["day"]

        mfcc_target = self._unfolded_mfcc_targets(batch)

        X, y, X_len, y_len, dayIdx = X.to(self.device), y.to(self.device), X_len.to(self.device), y_len.to(self.device), dayIdx.to(self.device)

        # Noise augmentation. Built out-of-place so the tensor stored in the
        # caller's batch is never modified (``.to`` returns the same tensor when
        # it already lives on the target device).
        if self.white_noise_SD > 0:
            X = X + torch.randn(X.shape, device=self.device) * self.white_noise_SD
        if self.constant_offset_SD > 0:
            X = X + torch.randn([X.shape[0], 1, X.shape[2]], device=self.device) * self.constant_offset_SD

        # Forward pass
        pred, mfcc_pred = self.forward(X, dayIdx)

        input_lens = ((X_len - self.kernelLen) / self.strideLen).to(torch.int32)
        loss = self._joint_loss(pred, mfcc_pred, mfcc_target, y, input_lens, y_len)

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step - Computes loss and CER.

        Reads the same batch keys as ``training_step``. The CER is the summed
        edit distance between the greedy CTC decoding and the target sequences,
        divided by the summed target length.
        """
        X = batch["neural_feats"]
        y = batch["phone_seq"]
        X_len = batch["neural_time_bins"]
        y_len = batch["phone_seq_len"]
        dayIdx = batch["day"]

        mfcc_target = self._unfolded_mfcc_targets(batch)

        X, y, X_len, y_len, dayIdx = X.to(self.device), y.to(self.device), X_len.to(self.device), y_len.to(self.device), dayIdx.to(self.device)

        pred, mfcc_pred = self.forward(X, dayIdx)

        # Number of output frames attributed to every sequence; shared by the
        # loss and the decoding below.
        input_lens = ((X_len - self.kernelLen) / self.strideLen).to(torch.int32)
        loss = self._joint_loss(pred, mfcc_pred, mfcc_target, y, input_lens, y_len)

        # Compute CER (Phoneme Error Rate)
        total_edit_distance, total_seq_length = 0, 0
        for i in range(pred.shape[0]):
            # Decode exactly the frames the CTC loss was computed on; frames
            # past that point come from padding. max() guards against a
            # negative slice end for sequences shorter than one window.
            n_frames = max(int(input_lens[i]), 0)
            decodedSeq = torch.argmax(pred[i, :n_frames, :], dim=-1)
            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)  # collapse repeats
            decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()  # drop CTC blanks

            trueSeq = y[i][:y_len[i]].cpu().numpy()
            matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())
            total_edit_distance += matcher.distance()
            total_seq_length += len(trueSeq)

        cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_CER", cer, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler.

        Adam over all parameters of this module, with a ReduceLROnPlateau
        scheduler that follows the logged "val_loss".
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, betas=(0.9, 0.999),
                                      eps=1e-8,)

        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}


################################################# MAMBA


class MambaEncoder(nn.Module):
    """
    Sequence encoder built around a ``MambaModel``.

    Pipeline: linear projection from ``input_dim`` to ``d_model``, dropout, the
    Mamba stack (fed through ``inputs_embeds``), then a LayerNorm.

    Parameters
    ----------
    input_dim : size of the last dimension of the input.
    d_model : hidden size of the Mamba model and of the output.
    num_layers : number of hidden layers of the Mamba model.
    dropout : dropout probability applied after the input projection.
    **kwargs : passed on unchanged to ``MambaConfig``.
    """

    def __init__(self, input_dim, d_model, num_layers=4, dropout=0.0, **kwargs):
        super().__init__()

        mamba_cfg = MambaConfig(
            hidden_size=d_model,
            num_hidden_layers=num_layers,
            **kwargs,
        )
        self.mamba_model = MambaModel(mamba_cfg)
        self.input_proj = nn.Linear(input_dim, d_model)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Encode ``x`` and return the normalised last hidden state of the Mamba model."""
        embedded = self.dropout(self.input_proj(x))
        # The projected features are given directly as embeddings.
        mamba_out = self.mamba_model(inputs_embeds=embedded)
        return self.norm(mamba_out.last_hidden_state)
    



class LightningMambaDecoder_MFCC_v3(pl.LightningModule):
    """
    Phoneme decoder with a Mamba encoder and an auxiliary MFCC regression head.

    Pipeline (see ``get_neural_embedding``): channel selection ->
    ``GaussianSmoothing`` -> per-day linear transform + Softsign ->
    sliding-window unfold (``kernelLen`` / ``strideLen``) -> ``MambaEncoder``.
    Two linear heads then map every encoder step to
      * phoneme logits (``n_classes + 1`` outputs, index 0 is the CTC blank), and
      * ``strideLen`` stacked MFCC frames (``mfcc_dim * strideLen`` outputs).
    The training objective is ``CTC + mfcc_loss_weight * L1(mfcc)``.

    Parameters
    ----------
    neural_dim : number of input channels; only used to build the default
        ``channels``. The effective value is always ``len(channels)``.
    n_classes : number of phoneme classes, excluding the CTC blank.
    hidden_dim : encoder width (doubled when ``bidirectional`` is true).
    layer_dim : number of Mamba layers.
    nDays : number of day-specific input transforms.
    dropout : dropout probability passed to ``MambaEncoder``.
    strideLen, kernelLen : stride and window length of the neural unfold;
        ``strideLen`` is also the window and stride of the MFCC unfold.
    gaussianSmoothWidth : passed to ``GaussianSmoothing``.
    bidirectional : kept for compatibility with the GRU variant; here it only
        doubles the encoder / head width.
    learning_rate, weight_decay : Adam hyper-parameters.
    white_noise_SD, constant_offset_SD : standard deviations of the training
        noise augmentations (disabled when not > 0).
    mfcc_dim : number of MFCC coefficients predicted per frame.
    mfcc_loss_weight : weight of the L1 MFCC loss in the total loss.
    channels : indices of the input channels to use (default: all).
    clip_mfcc_afer_go : if true, MFCC targets start at ``batch["go_onset"]``.
    """

    def __init__(
        self,
        neural_dim,
        n_classes,
        hidden_dim,
        layer_dim,
        nDays=24,
        dropout=0.1,
        strideLen=4,
        kernelLen=14,
        gaussianSmoothWidth=0,
        bidirectional=False,
        learning_rate=1e-3,
        white_noise_SD=0.01,
        constant_offset_SD=0.01,
        weight_decay=1e-5,
        mfcc_dim=14,
        mfcc_loss_weight=1.0,
        channels=None,
        clip_mfcc_afer_go=False,
    ):
        super().__init__()

        self.layer_dim = layer_dim
        self.hidden_dim = hidden_dim
        self.n_classes = n_classes
        self.nDays = nDays
        self.strideLen = strideLen
        self.kernelLen = kernelLen
        self.gaussianSmoothWidth = gaussianSmoothWidth
        self.bidirectional = bidirectional
        self.learning_rate = learning_rate
        self.white_noise_SD = white_noise_SD
        self.constant_offset_SD = constant_offset_SD
        self.weight_decay = weight_decay
        self.mfcc_loss_weight = mfcc_loss_weight
        self.clip_mfcc_afer_go = clip_mfcc_afer_go
        self.mfcc_dim = mfcc_dim
        self.channels = np.arange(0, neural_dim) if channels is None else channels

        # The working input width is the number of selected channels, not the
        # `neural_dim` argument.
        print("Resetting neural_dim based on channels")
        self.neural_dim = len(self.channels)
        neural_dim = self.neural_dim
        print("neural_dim", neural_dim, self.neural_dim)

        self.inputLayerNonlinearity = nn.Softsign()
        # Both unfolds slide over the time axis of a (B, C, T, 1) tensor.
        self.unfolder = nn.Unfold((self.kernelLen, 1), dilation=1, padding=0, stride=self.strideLen)
        self.mfcc_unfolder = nn.Unfold((self.strideLen, 1), dilation=1, padding=0, stride=self.strideLen)

        self.gaussianSmoother = GaussianSmoothing(neural_dim, 20, self.gaussianSmoothWidth, dim=1)

        # Per-day transformation weights, initialised to the identity. The
        # randn draw is kept so the RNG state seen by the layers created below
        # is the same as before the loop was vectorised.
        self.dayWeights = nn.Parameter(torch.randn(nDays, neural_dim, neural_dim))
        self.dayBias = nn.Parameter(torch.zeros(nDays, 1, neural_dim))
        with torch.no_grad():
            self.dayWeights.copy_(torch.eye(neural_dim).expand(nDays, -1, -1))

        # Width doubles with `bidirectional` for compatibility with the GRU variant.
        encoder_dim = hidden_dim * 2 if bidirectional else hidden_dim

        # Mamba encoder over the unfolded windows (neural_dim * kernelLen features each)
        self.mamba_encoder = MambaEncoder(
            input_dim=neural_dim * self.kernelLen,
            d_model=encoder_dim,
            num_layers=layer_dim,
            dropout=dropout,
        )

        # Output heads
        self.fc_decoder_out = nn.Linear(encoder_dim, n_classes + 1)  # +1 for CTC blank
        self.mfcc_decoder = nn.Linear(encoder_dim, mfcc_dim * self.strideLen)

        # Loss functions
        self.ctc_loss = nn.CTCLoss(blank=0, reduction="mean", zero_infinity=True)
        self.l1oss = nn.L1Loss()

    def get_neural_embedding(self, neuralInput, dayIdx):
        """
        Encode neural features into one hidden vector per sliding window.

        neuralInput: (batch, time, features)
        dayIdx: 1-D integer tensor with the session index of every batch element

        Returns the ``MambaEncoder`` output, shape (batch, n_windows, d_model).
        """
        # Channel selection
        neuralInput = neuralInput[:, :, self.channels].contiguous()

        # The smoother is applied with channels on dim 1, hence the permutes.
        neuralInput = torch.permute(neuralInput, (0, 2, 1))
        neuralInput = self.gaussianSmoother(neuralInput)
        neuralInput = torch.permute(neuralInput, (0, 2, 1))

        # Apply day-specific affine transformation followed by Softsign
        dayWeights = torch.index_select(self.dayWeights, 0, dayIdx)
        dayBias = torch.index_select(self.dayBias, 0, dayIdx)
        transformedNeural = torch.einsum("btd,bdk->btk", neuralInput, dayWeights) + dayBias
        transformedNeural = self.inputLayerNonlinearity(transformedNeural)

        # Sliding window: (B, T, C) -> (B, C, T, 1) -> unfold -> (B, n_windows, C * kernelLen)
        stridedInputs = torch.permute(
            self.unfolder(torch.unsqueeze(torch.permute(transformedNeural, (0, 2, 1)), 3)), (0, 2, 1)
        )

        # Apply the Mamba encoder
        return self.mamba_encoder(stridedInputs)

    def forward(self, neuralInput, dayIdx):
        """
        Run the encoder and both output heads.

        Returns ``(phoneme_logits, mfcc_pred)`` with last dimensions
        ``n_classes + 1`` and ``mfcc_dim * strideLen`` respectively.
        """
        hid = self.get_neural_embedding(neuralInput, dayIdx)
        phoneme_logits = self.fc_decoder_out(hid)
        mfcc_pred = self.mfcc_decoder(hid)
        return phoneme_logits, mfcc_pred

    def _unpack_batch(self, batch):
        """Return ``(X, y, X_len, y_len, dayIdx)`` from ``batch``, moved to the module's device."""
        keys = ("neural_feats", "phone_seq", "neural_time_bins", "phone_seq_len", "day")
        return tuple(batch[key].to(self.device) for key in keys)

    def _prepare_mfcc_targets(self, batch):
        """
        Build the MFCC regression targets for a batch.

        Each entry of ``batch["mfcc"]`` (optionally clipped to start at
        ``batch["go_onset"]``) is zero-padded to a common length, moved to the
        module's device and grouped into non-overlapping windows of
        ``strideLen`` frames, giving one target row per encoder step.
        Assumes each entry is (time, coefficients) — TODO confirm against the
        dataset.
        """
        clip = self.clip_mfcc_afer_go
        mfcc_list = [
            # as_tensor avoids the extra copy (and warning) torch.tensor makes
            # when the entry is already a tensor.
            torch.as_tensor(mfcc[batch["go_onset"][idx]:] if clip else mfcc)
            for idx, mfcc in enumerate(batch["mfcc"])
        ]
        MFCC = pad_sequence(mfcc_list, batch_first=True).to(self.device)

        # (B, T, C) -> (B, C, T, 1) -> unfold -> (B, n_windows, C * strideLen)
        return torch.permute(
            self.mfcc_unfolder(torch.unsqueeze(torch.permute(MFCC, (0, 2, 1)), 3)), (0, 2, 1)
        )

    def _compute_loss(self, pred, mfcc_pred, MFCC, y, X_len, y_len):
        """Return ``CTC(pred, y) + mfcc_loss_weight * L1(mfcc_pred, MFCC)``."""
        ctc_loss = self.ctc_loss(
            torch.permute(pred.log_softmax(2), [1, 0, 2]),
            y,
            ((X_len - self.kernelLen) / self.strideLen).to(torch.int32),
            y_len,
        )

        # Prediction and target can differ in length; compare the common prefix.
        min_seq_len = min(MFCC.shape[1], mfcc_pred.shape[1])
        l1_loss = self.l1oss(mfcc_pred[:, :min_seq_len, :], MFCC[:, :min_seq_len, :])

        return ctc_loss + self.mfcc_loss_weight * l1_loss

    def training_step(self, batch, batch_idx):
        """
        Training step - Runs forward pass, computes loss, and returns it for backprop.
        """
        MFCC = self._prepare_mfcc_targets(batch)
        X, y, X_len, y_len, dayIdx = self._unpack_batch(batch)

        # Noise augmentation. Done out-of-place so the tensor stored in `batch`
        # is never modified (``.to`` returns the same object when the data is
        # already on the right device).
        if self.white_noise_SD > 0:
            X = X + torch.randn(X.shape, device=self.device) * self.white_noise_SD
        if self.constant_offset_SD > 0:
            X = X + torch.randn([X.shape[0], 1, X.shape[2]], device=self.device) * self.constant_offset_SD

        # Forward pass
        pred, mfcc_pred = self.forward(X, dayIdx)

        loss = self._compute_loss(pred, mfcc_pred, MFCC, y, X_len, y_len)

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step - Computes loss and CER.
        """
        MFCC = self._prepare_mfcc_targets(batch)
        X, y, X_len, y_len, dayIdx = self._unpack_batch(batch)

        pred, mfcc_pred = self.forward(X, dayIdx)

        loss = self._compute_loss(pred, mfcc_pred, MFCC, y, X_len, y_len)

        # Compute CER (Phoneme Error Rate) with greedy CTC decoding:
        # argmax per frame, collapse repeats, drop blanks (index 0).
        # NOTE(review): the decode window uses X_len / strideLen frames while
        # the CTC loss uses (X_len - kernelLen) / strideLen, so the CER
        # includes trailing frames outside the CTC input length — confirm
        # this is intended.
        total_edit_distance, total_seq_length = 0, 0
        for i in range(pred.shape[0]):
            decodedSeq = torch.argmax(pred[i, : int(X_len[i] / self.strideLen), :], dim=-1)
            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)
            decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()

            trueSeq = y[i][:y_len[i]].cpu().numpy()
            matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())
            total_edit_distance += matcher.distance()
            total_seq_length += len(trueSeq)

        cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_CER", cer, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler.

        Adam over all parameters, plus a ReduceLROnPlateau scheduler that
        halves the learning rate after 2 epochs without improvement of the
        monitored "val_loss".
        """
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8,
        )

        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
