import datetime
import math
from typing import ForwardRef

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat

import src.models.nn.utils as U
import src.utils as utils
import src.utils.config
from src.models.sequence.block import SequenceResidualBlock
from src.models.nn.components import Normalization
from src.models.sequence import ff
from src.models.sequence import mha
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Base class for all encoders in this module.

    An encoder receives one positional tensor plus optional keyword arguments
    and returns the (possibly transformed) tensor together with a dict of extra
    keyword arguments. Encoders are chained with ``U.PassthroughSequential``,
    which forwards these kwargs down the pipeline; the accumulated kwargs are
    finally handed to the model backbone.

    The base implementation is the identity and contributes no extra kwargs.
    """

    def forward(self, x, **kwargs):
        """Return ``x`` untouched together with an empty kwargs dict.

        Args:
            x: input tensor.
            **kwargs: additional information from the dataset (e.g. sequence
                lengths); ignored by the base class.

        Returns:
            tuple: ``(y, extra)`` where ``y`` is the output tensor and
            ``extra`` is a dict of keyword arguments for the model backbone.
        """
        extra_kwargs = {}
        return x, extra_kwargs

class PositionalIDEncoder(Encoder):
    """Pass ``x`` through unchanged and emit integer position ids.

    The ids are ``0 .. x.shape[-1] - 1`` (int64, on ``x``'s device), repeated
    once per entry of the leading axis, and returned under ``'position_ids'``.
    """

    def forward(self, x):
        n_rows, seq_len = x.shape[0], x.shape[-1]
        ids = torch.arange(seq_len, dtype=torch.long, device=x.device)
        ids = repeat(ids, 'l -> b l', b=n_rows)
        return x, {'position_ids': ids}

# Adapted from https://github.com/pytorch/examples/blob/master/word_language_model/model.py
class PositionalEncoder(Encoder):
    r"""Add absolute positional encodings to a batch-first sequence.

    The encodings have the same dimension as the embeddings so the two can be
    summed. By default fixed sinusoids of different frequencies are used:

    .. math::
        \text{PE}(pos, 2i) = \sin\left(pos / 10000^{2i / d_{model}}\right)

        \text{PE}(pos, 2i+1) = \cos\left(pos / 10000^{2i / d_{model}}\right)

    where :math:`pos` is the position in the sequence and :math:`i` is the
    embedding index. If ``pe_init`` is given, a learnable table initialized
    from :math:`\mathcal{N}(0, \text{pe\_init}^2)` is used instead.

    Args:
        d_model: the embedding dimension (required).
        dropout: dropout probability applied after adding the encoding
            (default=0.1).
        max_len: maximum supported sequence length (default=16384).
        pe_init: if not None, standard deviation of a learnable positional
            table; otherwise the fixed sinusoidal table is registered as a
            buffer.

    Examples:
        >>> pos_encoder = PositionalEncoder(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=16384, pe_init=None):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        if pe_init is not None:
            # Same (max_len, d_model) layout as the sinusoidal buffer so that
            # forward() broadcasts it over the batch axis. A (max_len, 1, d_model)
            # table would broadcast against (batch, L, d_model) into a wrong shape.
            self.pe = nn.Parameter(torch.empty(max_len, d_model))
            nn.init.normal_(self.pe, 0, pe_init)
        else:
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0.0, max_len).unsqueeze(1)
            div_term = torch.exp(
                -math.log(10000.0) * torch.arange(0.0, d_model, 2.0) / d_model
            )
            pe[:, 0::2] = torch.sin(position * div_term)
            # For odd d_model there is one fewer cosine column than sine column.
            pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
            self.register_buffer("pe", pe)

        self.attn_mask = None

    def forward(self, x):
        r"""Add the encodings of the first ``L`` positions to ``x``.

        Args:
            x: tensor of shape ``(..., L, d_model)``, typically
                ``(batch, L, d_model)``; the second-to-last axis is treated as
                the sequence axis. Requires ``L <= max_len``.

        Returns:
            Tensor with the same shape as ``x`` after adding the encoding and
            applying dropout.
        """
        x = x + self.pe[: x.size(-2)]
        return self.dropout(x)


class ClassEmbedding(Encoder):
    """Add a learned per-class embedding to every position of ``x``.

    Note: this could equally be written as a subclass of ``nn.Embedding``.
    """

    def __init__(self, n_classes, d_model):
        super().__init__()
        self.embedding = nn.Embedding(n_classes, d_model)

    def forward(self, x, y):
        # The class vector gets an extra axis at -2 so it broadcasts over the
        # sequence axis of x; assumes y is (B,) class indices and x is
        # (B, L, D) — TODO confirm against callers.
        class_vec = self.embedding(y).unsqueeze(-2)
        return x + class_vec  # (B, L, D)


class Conv1DEncoder(Encoder):
    """Apply a 1-D convolution along the sequence axis of a (B, L, D) tensor.

    ``nn.Conv1d`` is channels-first, so the input is swapped to (B, D, L)
    before the convolution and swapped back afterwards.
    """

    def __init__(self, d_input, d_model, kernel_size=25, stride=1, padding='same'):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv = nn.Conv1d(d_input, d_model, **conv_options)

    def forward(self, x):
        channels_first = x.transpose(1, 2)   # (B, L, D) -> (B, D, L)
        convolved = self.conv(channels_first)
        return convolved.transpose(1, 2)     # back to (B, L', D')

class LayerEncoder(Encoder):
    """Use an arbitrary SequenceModule layer as an encoder.

    The layer config is wrapped in a single ``SequenceResidualBlock`` with
    residual type ``'R'`` and no pooling, and is forced to run on
    non-transposed input.

    Args:
        d_model: model dimension fed to the residual block.
        prenorm: whether the block normalizes before the layer.
        norm: normalization type passed to the block.
        layer: mapping describing the layer. Required despite the ``None``
            default (kept for signature compatibility).

    Raises:
        ValueError: if ``layer`` is None.
    """

    def __init__(self, d_model, prenorm=False, norm='layer', layer=None):
        super().__init__()
        if layer is None:
            raise ValueError("LayerEncoder requires a `layer` config")

        # Shallow copy so the caller's config is not mutated as a side effect.
        layer = {**layer, "transposed": False}
        self.layer = SequenceResidualBlock(
            d_input=d_model,
            prenorm=prenorm,
            layer=layer,
            residual='R',
            norm=norm,
            pool=None,
        )

    def forward(self, x):
        x, _ = self.layer(x)  # discard the block's state
        return x


class TimestampEmbeddingEncoder(Encoder):
    """
    General time encoder for Pandas Timestamp objects (encoded as torch tensors).
    See MonashDataset for an example of how to return time features as 'z's.

    Each provided calendar attribute (hour, day, month, ...) is turned into a
    ``d_model`` vector and added to ``x``. A value of ``-1`` marks a missing
    entry. Valid values ``min..max`` are shifted to indices ``1..max-min+1``,
    and missing entries map to index 0. Index 0 is the ``padding_idx`` of the
    lookup tables (``table=True``). With ``table=False`` the index is rescaled
    and fed to a per-attribute ``Linear(1, d_model)`` instead.

    Args:
        d_model: output embedding dimension.
        table: use per-attribute ``nn.Embedding`` lookups instead of linear maps.
        features: optional collection of attribute names to keep; all
            attributes are used when None.
    """

    cardinalities = {
        'day': (1, 31),
        'hour': (0, 23),
        'minute': (0, 59),
        'second': (0, 59),
        'month': (1, 12),
        'year': (1950, 2010), # (1800, 3000) used to be (1970, datetime.datetime.now().year + 1) but was not enough for all datasets in monash
        'dayofweek': (0, 6),
        'dayofyear': (1, 366),
        'quarter': (1, 4),
        'week': (1, 53),
        'is_month_start': (0, 1),
        'is_month_end': (0, 1),
        'is_quarter_start': (0, 1),
        'is_quarter_end': (0, 1),
        'is_year_start': (0, 1),
        'is_year_end': (0, 1),
        'is_leap_year': (0, 1),
    }

    def __init__(self, d_model, table=False, features=None):
        super().__init__()
        self.table = table
        # Slots per attribute: one per valid value plus index 0 for nulls.
        self.ranges = {k: max_val - min_val + 2 for k, (min_val, max_val) in self.cardinalities.items()}

        if features is not None:
            self.cardinalities = {k: v for k, v in self.cardinalities.items() if k in features}

        if table:
            self.embedding = nn.ModuleDict({
                attr: nn.Embedding(maxval - minval + 2, d_model, padding_idx=0)
                for attr, (minval, maxval) in self.cardinalities.items()
            })
        else:
            self.embedding = nn.ModuleDict({
                attr: nn.Linear(1, d_model)
                for attr in self.cardinalities
            })

    def forward(self, x, timestamps=None):
        """Add the embedding of every attribute in ``timestamps`` to ``x``.

        Args:
            x: tensor the time embeddings are added to.
            timestamps: mapping attribute name -> integer tensor of raw values
                (``-1`` for missing). Not modified. If None or empty, ``x`` is
                returned unchanged.

        Returns:
            ``x`` plus the summed attribute embeddings.
        """
        if not timestamps:
            return x
        for attr, values in timestamps.items():
            min_val = self.cardinalities[attr][0]
            is_null = values == -1
            # Valid values -> 1..(max-min+1); nulls -> 0 (the padding slot), so
            # the smallest valid value no longer collides with "missing".
            index = (values - min_val + 1).masked_fill(is_null, 0)
            if self.table:
                x = x + self.embedding[attr](index.to(torch.long))
            else:
                scaled = 2 * index / self.ranges[attr] - 1
                x = x + self.embedding[attr](scaled.unsqueeze(-1))
        return x


class TimeEncoder(Encoder):
    """Add calendar-time features and a binary mask embedding to ``x``.

    With ``timeenc == 0`` each column of ``mark`` indexes its own
    ``nn.Embedding`` (sizes given by ``n_tokens_time``) and the embeddings are
    summed. Otherwise ``mark`` is mapped by one ``nn.Linear`` from
    ``len(n_tokens_time)`` features to ``d_model``. ``mask`` is looked up in a
    2-row embedding table.
    """

    def __init__(self, n_tokens_time, d_model, timeenc=0):
        super().__init__()

        self.timeenc = timeenc
        use_lookup = self.timeenc == 0
        self.encoders = (
            nn.ModuleList([nn.Embedding(size, d_model) for size in n_tokens_time])
            if use_lookup
            else nn.Linear(len(n_tokens_time), d_model)
        )
        self.mask_embed = nn.Embedding(2, d_model)

    def forward(self, x, mark=None, mask=None):
        assert mark is not None and mask is not None, "Extra arguments should be returned by collate function"
        if self.timeenc != 0:
            time_encode = self.encoders(mark)
        else:
            assert mark.size(-1) == len(self.encoders)
            columns = torch.unbind(mark, dim=-1)
            per_column = [emb(col) for emb, col in zip(self.encoders, columns)]
            time_encode = torch.stack(per_column).sum(dim=0)
        mask_encode = self.mask_embed(mask.squeeze(-1))
        return x + time_encode + mask_encode  # (B, L, d_model)


class PackedEncoder(Encoder):
    """Pack a padded batch-first tensor into a ``PackedSequence``.

    ``len_batch`` holds the true length of every sequence. It is moved to the
    CPU before packing, and the sequences need not be sorted by length.
    """

    def forward(self, x, len_batch=None):
        assert len_batch is not None
        lengths_cpu = len_batch.cpu()
        return nn.utils.rnn.pack_padded_sequence(
            x, lengths_cpu, batch_first=True, enforce_sorted=False,
        )


class OneHotEncoder(Encoder):
    """One-hot encode integer tokens into ``d_model``-wide float vectors.

    A trailing axis of size 1 is squeezed before encoding. ``n_tokens`` must not
    exceed ``d_model`` so every token has its own slot.
    """

    def __init__(self, n_tokens, d_model):
        super().__init__()
        assert n_tokens <= d_model
        self.d_model = d_model

    def forward(self, x):
        tokens = x.squeeze(-1)
        encoded = F.one_hot(tokens, self.d_model)
        return encoded.float()


class Conv2DPatchEncoder(Encoder):
    """
    For encoding images into a sequence of patches.

    A ``Conv2d`` whose kernel size equals its stride cuts the image into
    non-overlapping ``fh x fw`` patches and projects each to ``d_model``
    channels.
    """

    def __init__(self, d_input, d_model, filter_sizes, flat=False):
        """
        Args:
            d_input: dim of encoder input (data / channel dimension).
            d_model: dim of encoder output (model dimension).
            filter_sizes: tuple ``(fh, fw)`` giving the patch height and width.
            flat: if True, the dataloader yields images flattened to
                ``[b, h * w, c]`` (as for CIFAR). They are reshaped back to a
                square ``[b, h, h, c]`` image before the convolution.
        """
        # nn.Module.__init__ must run before attributes are set on the module.
        super().__init__()
        assert len(filter_sizes) == 2
        fh, fw = filter_sizes
        self.flat = flat

        self.encoder = nn.Conv2d(d_input, d_model, kernel_size=(fh, fw), stride=(fh, fw))

    def forward(self, x):
        """
        Args:
            x: image batch of shape ``[b, h, w, c]``, or ``[b, h * w, c]`` for a
                square image when ``flat`` is True.

        Returns:
            Tensor of shape ``[b, seq_len, d_model]`` with
            ``seq_len = (h // fh) * (w // fw)``.

        Raises:
            ValueError: if ``flat`` input does not hold a square number of pixels.
        """
        if self.flat and x.dim() == 3:
            n_pixels = x.shape[1]
            side = math.isqrt(n_pixels)
            if side * side != n_pixels:
                raise ValueError(
                    f"flat input must contain a square number of pixels, got {n_pixels}"
                )
            x = rearrange(x, 'b (h w) c -> b h w c', h=side)

        x = rearrange(x, 'b h w c -> b c h w')
        x = self.encoder(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        return x


class Encoder3DSegmentation(Encoder):
    """Project every time step of a cubic patch to a hidden vector.

    The ``patch_size**3`` voxels of each step are flattened and passed through
    one linear layer. A single input channel is assumed for now.
    """
    _name_ = "encoder_3d_segmentation"

    def __init__(self, patch_size=16, hidden_dim=128):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.patch_size = patch_size
        voxels_per_patch = patch_size ** 3
        self.linear = nn.Linear(voxels_per_patch, self.hidden_dim)

    def forward(self, x):
        # Expected input: (batch, p, p, p, simulation_steps) — per original note.
        voxels = self.patch_size ** 3
        flattened = x.view(x.shape[0], voxels, -1)   # (B, p**3, T)
        steps_first = flattened.permute(0, 2, 1)     # (B, T, p**3)
        return self.linear(steps_first)              # (B, T, hidden_dim)


class TimeseriesSyntheticsEncoder(Encoder):
    """Linearly project channel-first synthetic time series to ``d_model``.

    Input is ``(batch, F, T)`` with
    ``F = loan_pool_size * (num_states + 1) + macro_features``; output is
    ``(batch, T, d_model)``. Extra keyword arguments are accepted and ignored.
    """
    _name_ = "timeseries_synthetics"

    def __init__(
            self, num_states, loan_pool_size, d_model, macro_features=1, **kwargs):
        super().__init__()
        self.loan_pool_size = loan_pool_size
        self.num_states = num_states
        self.d_model = d_model
        in_features = loan_pool_size * (num_states + 1) + macro_features
        self.linear = nn.Linear(in_features, d_model)

    def forward(self, x):
        time_major = x.permute(0, 2, 1)   # (B, F, T) -> (B, T, F)
        return self.linear(time_major)    # (B, T, d_model)


class PositionalEncoderLinear(Encoder):
    """Project synthetic time series, scale by ``sqrt(d_model)``, add positions.

    ``TimeseriesSyntheticsEncoder`` maps the input to ``(batch, T, d_model)``.
    The result is multiplied by ``sqrt(d_model)`` before the sinusoidal
    ``PositionalEncoder`` (with dropout) is applied.
    """
    _name_ = "positional_linear"

    def __init__(self, num_states,
                 loan_pool_size, d_model, dropout=0):
        super().__init__()
        self.TsEncoder = TimeseriesSyntheticsEncoder(num_states, loan_pool_size, d_model)
        self.PosEncoder = PositionalEncoder(d_model, dropout)
        self.d_model = d_model

    def forward(self, x):
        projected = self.TsEncoder(x)
        scale = torch.sqrt(torch.tensor(self.d_model))
        return self.PosEncoder(projected * scale)
    

class CalendarPositionalEncoder(Encoder):
    """Like ``PositionalEncoderLinear`` but with a per-sample position offset.

    Sample ``i`` receives the sinusoidal encodings of positions
    ``I[i] .. I[i] + T - 1`` instead of ``0 .. T - 1``.
    """
    _name_ = "calendar_positional_linear"

    def __init__(self, num_states,
                 loan_pool_size, d_model, dropout=0):
        super().__init__()
        self.TsEncoder = TimeseriesSyntheticsEncoder(
                             num_states, loan_pool_size, d_model)
        self.PosEncoder = PositionalEncoder(d_model, dropout)
        self.dropout = self.PosEncoder.dropout
        self.d_model = d_model

    @property
    def pe(self):
        """The ``(max_len, d_model)`` positional table of ``PosEncoder``.

        Read through the submodule (a registered buffer) so it follows the
        module's device instead of being pinned to CUDA at construction.
        """
        return self.PosEncoder.pe

    def forward(self, x, I):
        """
        Args:
            x: input for ``TimeseriesSyntheticsEncoder`` (batch-first).
            I: tensor of shape ``(batch,)`` with each sample's start position.

        Returns:
            Tensor of shape ``(batch, T, d_model)``.

        Raises:
            ValueError: if a requested position falls outside the table.
        """
        x = self.TsEncoder(x)
        x = x * torch.sqrt(torch.tensor(self.d_model))
        # x: (batch, T, d_model); pe: (max_len, d_model)
        pe = self.pe
        seq_len = x.shape[-2]
        starts = I.long().to(pe.device)
        if starts.numel() and (
            int(starts.min()) < 0 or int(starts.max()) + seq_len > pe.shape[0]
        ):
            raise ValueError(
                f"positions out of range for positional table of length {pe.shape[0]}"
            )
        positions = starts.unsqueeze(-1) + torch.arange(seq_len, device=pe.device)  # (batch, T)
        x = x + pe[positions].to(x.device)
        return self.dropout(x)

class AbsoluteTimeEncoder(Encoder):
    """Append an absolute-time row to the input, then project it linearly.

    For sample ``i`` the extra feature row is ``(I[i] + t) / 1000`` for
    ``t = 0 .. T-1``. ``TimeseriesSyntheticsEncoder`` is built with
    ``macro_features=2`` to account for that extra row.
    """
    _name_ = "absolute_time"

    def __init__(self, num_states,
                 loan_pool_size, d_model, dropout=0):
        super().__init__()
        self.TsEncoder = TimeseriesSyntheticsEncoder(
                            num_states, loan_pool_size, d_model,
                            macro_features=2)
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, I):
        """
        Args:
            x: tensor of shape ``(batch, F, T)``.
            I: per-sample start times, length ``batch`` (tensor or sequence);
                fractional values are truncated toward zero.

        Returns:
            Tensor of shape ``(batch, T, d_model)``.
        """
        steps = x.shape[2]
        starts = torch.as_tensor(I, device=x.device).long()
        # Build all time rows at once on x's device instead of a CPU loop + .cuda().
        absolute_t = starts.unsqueeze(-1) + torch.arange(steps, device=x.device)  # (batch, T)
        time_feature = absolute_t.unsqueeze(1).to(x.dtype) / 1000  # (batch, 1, T)
        x = torch.cat([x, time_feature], dim=1)
        x = self.TsEncoder(x)  # (batch, T, d_model)
        return self.dropout(x)

def log_grad(name):
    """Build a backward hook that prints the gradient norm of ``name``.

    The returned callable takes the gradient tensor, prints
    ``[GRAD] <name> grad norm: <norm>`` with six decimals, and returns None
    (so it leaves the gradient unchanged when used with ``register_hook``).
    """
    def _print_grad_norm(grad):
        norm_value = grad.norm().item()
        print(f"[GRAD] {name} grad norm: {norm_value:.6f}")

    return _print_grad_norm


class StackedEncoder(Encoder):
    """Stack all units into the feature axis, creating a larger feature dimension.

    Input ``(batch, nr_features, nr_units, T)`` is flattened to
    ``(batch, T, nr_features * nr_units)`` and passed through one linear layer
    to ``d_model``. ``nr_units`` is reported to downstream modules.
    """
    _name_ = "stacked_encoder"

    def __init__(self, d_model, nr_features, nr_units, **kwargs):
        super().__init__()
        self.d_model = d_model
        self.input_feature_dim = nr_features
        self.nr_units = nr_units
        stacked_width = nr_features * nr_units
        self.stacked_encoder = nn.Linear(stacked_width, d_model)

    def forward(self, x, **kwargs):
        dims = x.shape
        merged = x.reshape(dims[0], dims[1] * dims[2], dims[3])  # (B, F*U, T)
        time_major = merged.permute(0, 2, 1)                      # (B, T, F*U)
        return self.stacked_encoder(time_major), {"nr_units": self.nr_units}
        
        

class SetEncoder(Encoder):
    """
    A Set-based Encoder for the Set-Sequence model architecture.

    This encoder operates on a batch of time series corresponding to a set of units (e.g., loans or stocks),
    where each unit's behavior may depend on both its individual features and shared cross-sectional structure.
    The encoder computes a permutation-invariant summary across units at each timestep using `m_2`,
    and augments each unit's time series with this summary before passing it through a feedforward network `m_3`.

    Attributes:
        m_2 (nn.Module): The set function module (`M2`) that generates shared set summaries (not built when `debug`).
        m_3 (nn.Module): A feedforward network that processes the concatenated unit-specific and set-level features
            (identity when `id` is True).
        use_layer_norm_set (bool): Whether to add a residual `x + LayerNorm(m_3(...))` when the shapes allow it.
        architecture (str): Type of architecture used for the set function ("MLP", "MHA", "gated_selection").
        num_states (int): Number of latent state variables per unit.
        pool_size (int): Number of units per sample (loan pool size).
        common_pool_embedding_dim (int): Dimensionality of the set summary vector.
        feature_embedding_dim (int): Dimensionality of per-unit feature embedding.
        chunk_size (int): Number of timesteps used in the temporal chunk for summary generation.
        n_attn_summary_statistics (bool): Whether MHA / gated selection yield a per-unit summary.
        dropout (float): Dropout rate used in set and feedforward layers.
        projection (bool): Whether to include projection layer in `m_2`.
        set_var_layer_norm (bool): Stored; not used by this class.
        aggregator_only (bool): Whether to zero the unit features and keep only the aggregator output (ablation).
        debug (bool): If True, `m_2` is skipped and an all-zero summary is used.
        id (bool): If True, `m_3` is the identity.
        include_y_in_output (bool): Whether to return the outcome feature `y` in the output dictionary.
        return_raw_input (bool): Whether to return the raw reshaped input instead of encoding.
    """
    _name_ = "set_encoder"

    def __init__(
            self,
            num_states,
            loan_pool_size,
            d_model,
            common_pool_embedding_dim=2,
            feature_embedding_dim=5,
            debug = False,
            architecture="MLP",
            nr_attention_heads=128,
            use_layer_norm_set=False,
            chunk_size=1,
            n_attn_summary_statistics=False,
            dropout=0,
            expand=2,
            projection=False,
            set_var_layer_norm=False,
            aggregator_only=False,
            id = False,  # shadows the builtin; name kept for config compatibility
            v_dim=5,
            include_y_in_output=False,
            return_raw_input=False,
            ):
        super().__init__()
        self.use_layer_norm_set = use_layer_norm_set
        self.debug = debug
        self.d_model = d_model
        self.feature_embedding_dim = feature_embedding_dim
        self.include_y_in_output = include_y_in_output
        self.return_raw_input = return_raw_input
        # Per-unit covariates that are not latent states.
        nr_non_state_feature_per_unit = 2

        if not num_states:
            num_states = 0
        self.input_feature_dim = num_states + nr_non_state_feature_per_unit
        self.num_states = num_states
        self.pool_size = loan_pool_size
        self.common_pool_embedding_dim = common_pool_embedding_dim
        self.architecture = architecture
        self.nr_attention_heads = nr_attention_heads
        self.chunk_size = chunk_size
        self.n_attn_summary_statistics = n_attn_summary_statistics
        self.dropout = dropout
        self.expand = expand
        self.projection = projection
        self.set_var_layer_norm = set_var_layer_norm
        self.aggregator_only = aggregator_only
        self.id = id
        self.v_dim = v_dim
        if self.use_layer_norm_set:
            self.norm = nn.LayerNorm(self.input_feature_dim)
        # Submodules operate on the last dimension of their input.
        if not debug:
            self.m_2 = M2(
                d_input=self.input_feature_dim,
                d_output=self.common_pool_embedding_dim,
                m1_output_dim=self.feature_embedding_dim,
                chunk_size=self.chunk_size,
                dropout=self.dropout,
                projection=self.projection,
                architecture=self.architecture,
                nr_attention_heads=self.nr_attention_heads,
                n_attn_summary_statistics=self.n_attn_summary_statistics,
                v_dim=self.v_dim,
            )
        if not self.id:
            self.m_3 = ff.FF(
                d_input=(self.common_pool_embedding_dim + self.input_feature_dim),
                d_output=self.d_model,
                dropout=self.dropout,
                expand=self.expand,
            )
        else:
            # Identity that mimics ff.FF's (output, state) return convention.
            self.m_3 = lambda x: (x, {})

    def forward(self, x):
        """
        Forward pass for the SetEncoder.

        Args:
            x (torch.Tensor): Input tensor of shape
                (batch_size, num_covariates, num_units, num_timesteps).
                Covariate 0 is returned as `y` when requested.

        Returns:
            Tuple[torch.Tensor, dict]:
                - Encoded representation of shape
                  (batch_size * num_units, num_timesteps, d_model), or the raw
                  (batch_size * num_units, num_timesteps, num_covariates)
                  input when `return_raw_input` is True.
                - Dictionary containing metadata:
                    - "y": only if `include_y_in_output` is True; covariate 0,
                      shape (batch_size, num_units, num_timesteps).
                    - "nr_units": The number of units (e.g., loans or stocks) per sample.
        """
        nr_units = x.shape[2]
        # (B, C, U, T) -> (B, U, T, C): one feature vector per unit and step.
        x_in = x.transpose(1, 2).transpose(2, 3)

        extra = {}
        if self.include_y_in_output:
            extra["y"] = x_in[:, :, :, 0]
        extra["nr_units"] = nr_units

        if self.return_raw_input:
            # Previously raised NameError on `y` when include_y_in_output was False.
            return x_in.reshape(-1, x_in.shape[2], x_in.shape[3]), extra

        # Whether the set summary is already per unit: (B, U, T, E) vs (B, T, E).
        per_unit_summary = (
            self.n_attn_summary_statistics
            and self.architecture in ["MHA", "gated_selection"]
        )
        if not self.debug:
            x_2 = self.m_2(x_in)
        elif per_unit_summary:
            x_2 = torch.zeros(
                (x_in.shape[0], x_in.shape[1], x_in.shape[2], self.common_pool_embedding_dim),
                device=x_in.device,
            )
        else:
            x_2 = torch.zeros(
                (x_in.shape[0], x_in.shape[2], self.common_pool_embedding_dim),
                device=x_in.device,
            )

        if not per_unit_summary:
            # Broadcast the pool-level summary to every unit: (B, U, T, E).
            x_2 = x_2.unsqueeze(1).expand(-1, nr_units, -1, -1)

        if self.aggregator_only:
            x_in = torch.zeros_like(x_in)

        # (B, U, T, C + E) -> (B * U, T, C + E) -> m_3 -> (B * U, T, d_model)
        hat_x = torch.cat([x_in, x_2], dim=3)
        hat_hat_x = hat_x.reshape(-1, hat_x.shape[2], hat_x.shape[3])
        x_3, _ = self.m_3(hat_hat_x)

        x_original_reshaped = x_in.reshape(-1, x_in.shape[2], x_in.shape[3])
        if self.use_layer_norm_set and x_original_reshaped.shape == x_3.shape:
            # LayerNorm normalizes the last axis only, so it can be applied to
            # the 3-D tensor directly.
            x_3 = x_original_reshaped + self.norm(x_3)

        return x_3, extra


class M2(nn.Module):
    """Chunk-wise per-unit transform (M1) followed by a cross-unit aggregator."""

    def __init__(
        self,
        d_input,
        d_output,
        m1_output_dim,
        chunk_size=1,
        dropout=0,
        projection=False,
        architecture="MLP",
        nr_attention_heads=128,
        n_attn_summary_statistics=False,
        v_dim=5,
    ):
        """
        A hybrid module that combines local chunk-based transformation and global set aggregation.

        This module performs two main operations:
        1. A local chunk-wise feature transformation across time.
        2. A permutation-invariant aggregation across the cross-sectional units.

        The architecture supports multiple aggregation modes:
        - "MLP": mean-pooling followed by an FFN.
        - "MHA": cross-unit multi-head attention followed by an FFN.
        - "gated_selection": cosine-similarity gating between units, computed at one timestep.

        Args:
            d_input (int): Input feature dimension per timestep.
            d_output (int): Output dimension after aggregation.
            m1_output_dim (int): Output dimension of the chunk-based transform (M1).
            chunk_size (int): Temporal window size for local feature extraction.
            dropout (float): Dropout rate applied in the M1 FF layer.
            projection (bool): If True, M1 is an ff.FF with expansion; otherwise a single Linear.
            architecture (str): Aggregation type: "MLP", "MHA", or "gated_selection".
            nr_attention_heads (int): Number of heads for MHA aggregation.
            n_attn_summary_statistics (bool): If True, MHA skips the mean over units and
                returns per-unit projected outputs.
            v_dim (int): Internal value dimension for gated selection.
        """
        super().__init__()
        self.chunk_size = chunk_size
        self.m1_output_dim = m1_output_dim
        self.n_attn_summary_statistics = n_attn_summary_statistics
        self.architecture = architecture
        self.nr_attention_heads = nr_attention_heads
        self.v_dim = v_dim

        # -------------------------------------------------------------
        # 1) M1: per-unit transform of a sliding window of `chunk_size`
        # -------------------------------------------------------------
        if projection:
            expand = (m1_output_dim * 2.0) / d_input
            self.nn_m1 = ff.FF(
                d_input=d_input * chunk_size,
                d_output=self.m1_output_dim,
                dropout=dropout,
                expand=expand,
            )
        else:
            self.nn_m1 = nn.Linear(d_input * chunk_size, self.m1_output_dim)

        # -------------------------------------------------------------
        # 2) Aggregator across units, operating on M1 outputs
        # -------------------------------------------------------------
        # NOTE: submodule construction order below is intentionally unchanged
        # so that parameter initialization consumes the RNG exactly as before.
        self.d_input = d_input
        self.d_output = d_output
        if architecture == "MLP" or (not n_attn_summary_statistics):
            self.nn = ff.FF(d_input=self.m1_output_dim, d_output=self.d_output)

        if architecture == "MHA":
            # May replace the FF built just above; see NOTE on construction order.
            self.nn = ff.FF(d_input=self.m1_output_dim, d_output=self.d_output)
            self.mha = mha.MultiheadAttention(
                d_model=self.m1_output_dim,
                n_heads=self.nr_attention_heads,
                causal=False,
            )
            self.proj_token_dim = nn.Linear(self.m1_output_dim, self.d_output)
            # Not used in forward; kept so existing checkpoints keep loading.
            self.proj_pool_dim = nn.Linear(10, 1)

        if architecture == "gated_selection":
            self.gate_net = AttentionLikeGate(d_input=self.m1_output_dim)
            self.V = nn.Linear(self.m1_output_dim, self.v_dim)
            self.proj_token_dim = nn.Linear(self.v_dim, self.d_output)
            self.K = nn.Linear(self.m1_output_dim, self.m1_output_dim)

    def _m1_forward(self, x):
        """Apply M1 to a causal sliding window of ``chunk_size`` steps.

        The sequence is left-padded by repeating its first step
        ``chunk_size - 1`` times, so output step ``t`` sees input steps
        ``t - chunk_size + 1 .. t``.

        Args:
            x: tensor of shape ``(B, P, L, d_input)``.

        Returns:
            Tensor of shape ``(B, P, L, m1_output_dim)``.
        """
        BZ, units_per_simulation, L, d_input = x.size()
        first_value = x[:, :, 0:1, :]  # (B, P, 1, d_input)
        padding = first_value.expand(-1, -1, self.chunk_size - 1, -1)
        x_padded = torch.cat((padding, x), dim=2)  # (B, P, L + chunk_size - 1, d_input)

        # unfold appends the window as the LAST axis: (B, P, L, d_input, chunk_size).
        # The flattened vector is therefore feature-major; this only fixes how
        # nn_m1's input weights are laid out.
        x_chunks = x_padded.unfold(dimension=2, size=self.chunk_size, step=1)
        x_chunks = x_chunks.contiguous().view(-1, d_input * self.chunk_size)  # (B*P*L, d_input*chunk_size)

        output = self.nn_m1(x_chunks)
        # ff.FF returns a tuple whose first element is the output; nn.Linear does not.
        if isinstance(output, tuple):
            output = output[0]
        return output.view(BZ, units_per_simulation, L, -1)

    def _mean(self, x: torch.Tensor, dim: int) -> torch.Tensor:
        """Take the mean of x along dimension dim."""
        return torch.mean(x, dim=dim)

    def gated_selection_forward(self, x, t_g=0):
        """Aggregate units with a gating distribution computed at step ``t_g``.

        For every pair of units ``(i, j)`` the cosine similarity of their
        ``K``-projected M1 embeddings at step ``t_g`` is softmax-normalized over
        ``j``. The same weights are then used to mix the ``V``-projected
        embeddings at every timestep.

        Args:
            x: tensor of shape ``(B, P, T, d_input)``.
            t_g: timestep whose embeddings define the gating.

        Returns:
            Tensor of shape ``(B, P, T, d_output)``.
        """
        B, P, T, D = x.shape

        x1 = self._m1_forward(x)  # (B, P, T, m1_output_dim)
        xk = self.K(x1)           # (B, P, T, m1_output_dim)
        e_tg = xk[:, :, t_g, :]   # (B, P, m1_output_dim)

        # All (i, j) pairs, concatenated along the feature axis.
        e_tg_i = e_tg.unsqueeze(2).expand(-1, -1, P, -1)  # (B, P, P, m1_output_dim)
        e_tg_j = e_tg.unsqueeze(1).expand(-1, P, -1, -1)  # (B, P, P, m1_output_dim)
        e_tg_ij = torch.cat([e_tg_i, e_tg_j], dim=-1)     # (B, P, P, 2*m1_output_dim)

        gates_raw, _ = self.gate_net(e_tg_ij)             # (B, P, P, 1)
        G = F.softmax(gates_raw.squeeze(-1), dim=-1)      # (B, P, P), rows sum to 1

        xv = self.V(x1)                                   # (B, P, T, v_dim)
        x_gated = torch.einsum("b i j, b j t d -> b i t d", G, xv)  # (B, P, T, v_dim)
        return self.proj_token_dim(x_gated)               # (B, P, T, d_output)

    def mha_forward(self, x, dim=1):
        """Attend across units at every timestep, then summarize.

        Args:
            x: M1 output of shape ``(B, P, L, D)``.
            dim: axis averaged over when per-unit outputs are not requested.

        Returns:
            ``(B, P, L, d_output)`` if ``n_attn_summary_statistics`` is set;
            otherwise ``self.nn`` applied to the mean of the attention output
            over ``dim``.
        """
        batch, pool, seq_len, token_dim = x.size()
        # Every timestep becomes an independent set of `pool` tokens:
        # (B, P, L, D) -> (B * L, P, D). Cost scales linearly with L.
        tokens = x.permute(0, 2, 1, 3).reshape(batch * seq_len, pool, token_dim)
        try:
            attn_output, _ = self.mha(tokens)
        except RuntimeError:
            # mha fails when the effective batch is too large; retry in slices.
            # Narrowed from `Exception` so unrelated errors are not masked.
            max_rows = 10000
            attn_output = torch.cat(
                [self.mha(part)[0] for part in tokens.split(max_rows, dim=0)],
                dim=0,
            )
        # (B * L, P, D) -> (B, P, L, D)
        attn_output = attn_output.reshape(batch, seq_len, pool, token_dim).permute(0, 2, 1, 3)

        if self.n_attn_summary_statistics:
            return self.proj_token_dim(attn_output)  # (B, P, L, d_output)

        pooled = self._mean(attn_output, dim=dim)
        y, _ = self.nn(pooled)
        return y

    def forward(self, x, dim=1, debug=False):
        """
        x has shape (B, P, L, d_input); M1 is applied before aggregation.
        If debug=True, M1 and the aggregator are skipped and zeros are returned.
        """
        if not debug:
            if self.architecture == "gated_selection":
                return self.gated_selection_forward(x)
            elif self.architecture == "MHA":
                x1 = self._m1_forward(x)  # (B, P, L, m1_output_dim)
                return self.mha_forward(x1, dim=dim)
            elif self.architecture == "MLP":
                x1 = self._m1_forward(x)  # (B, P, L, m1_output_dim)
                x_1_mean = self._mean(x1, dim=dim)
                y, _ = self.nn(x_1_mean)
                return y
            else:
                raise ValueError(f"Unsupported architecture {self.architecture}")
        else:
            if self.n_attn_summary_statistics and self.architecture in ["MHA", "gated_selection"]:
                return torch.zeros(
                    x.shape[0], x.shape[1], x.shape[2], self.m1_output_dim, device=x.device
                )
            else:
                # (B, L, m1_output_dim)
                return torch.zeros(
                    x.shape[0], x.shape[2], self.m1_output_dim, device=x.device
                )


class AttentionLikeGate(nn.Module):
    """Parameter-free gate: cosine similarity between two concatenated halves.

    The last dimension of the input is split into two equal halves ``x_i`` and
    ``x_j``. The gate value is ``<x_i, x_j> / (||x_i|| * ||x_j|| + eps)``, which
    is their cosine similarity with a small epsilon in the denominator. No
    learnable projection is applied.
    """

    def __init__(self, d_input):
        """
        Args:
            d_input (int): Intended dimension of each half (so the input's last
                dimension is ``2 * d_input``). It is accepted for constructor
                compatibility only; the split point is derived from the input at
                call time.
        """
        super().__init__()

    def forward(self, x):
        """
        Args:
            x (Tensor): Input of shape ``(..., 2 * D)``. The first half along the
                last dimension is ``x_i`` and the second half is ``x_j``.

        Returns:
            tuple[Tensor, dict]: The cosine similarity of shape ``(..., 1)``
            (0 when either half is all zeros) and an empty dict of extra outputs.

        Raises:
            ValueError: If the last dimension is not a positive even number.
        """
        width = x.shape[-1]
        # An odd width would make torch.split yield three chunks and a zero width
        # makes the split size 0; fail early with a clear message instead.
        if width == 0 or width % 2:
            raise ValueError(
                f"AttentionLikeGate expects a positive even last dimension, got {width}"
            )
        d = width // 2

        # Split into the two D-dimensional halves along the last axis.
        x_i, x_j = torch.split(x, d, dim=-1)

        dot_product = (x_i * x_j).sum(dim=-1, keepdim=True)
        norm_x_i = torch.norm(x_i, p=2, dim=-1, keepdim=True)
        norm_x_j = torch.norm(x_j, p=2, dim=-1, keepdim=True)

        # Epsilon keeps the division finite when either half is all zeros.
        epsilon = 1e-8
        gate_value = dot_product / (norm_x_i * norm_x_j + epsilon)

        return gate_value, {}


# For every type of encoder/decoder, specify:
# - constructor class
# - list of attributes to grab from dataset
# - list of attributes to grab from model

# Maps a config name (a plain string, or the "_name_" entry of a config dict)
# to the class used to construct that encoder.
registry = {
    "stop": Encoder,
    "id": nn.Identity,
    "embedding": nn.Embedding,
    "linear": nn.Linear,
    "position": PositionalEncoder,
    "position_id": PositionalIDEncoder,
    "class": ClassEmbedding,
    "pack": PackedEncoder,
    "time": TimeEncoder,
    "onehot": OneHotEncoder,
    "conv1d": Conv1DEncoder,
    "patch2d": Conv2DPatchEncoder,
    "timestamp_embedding": TimestampEmbeddingEncoder,
    "layer": LayerEncoder,
    "encoder_3d_segmentation": Encoder3DSegmentation,
    "timeseries_synthetics": TimeseriesSyntheticsEncoder,
    "positional_linear": PositionalEncoderLinear,
    "calendar_positional_linear": CalendarPositionalEncoder,
    "absolute_time": AbsoluteTimeEncoder,
    "set_encoder": SetEncoder,
    "stacked_encoder": StackedEncoder,
}
# Attribute names read from the dataset object for each encoder name.
# `_instantiate` looks these up with `.get(name, [])`, so names absent here
# receive no dataset-derived arguments.
dataset_attrs = {
    "embedding": ["n_tokens"],
    "linear": ["d_input"],  # TODO make this d_data?
    "class": ["n_classes"],
    "time": ["n_tokens_time"],
    "onehot": ["n_tokens"],
    "conv1d": ["d_input"],
    "patch2d": ["d_input"],
    "timeseries_synthetics": ["num_states", "loan_pool_size"],
    "positional_linear": ["num_states", "loan_pool_size"],
    "calendar_positional_linear": ["num_states", "loan_pool_size"],
    "absolute_time": ["num_states", "loan_pool_size"],
    "set_encoder": ["num_states", "loan_pool_size"],
}
# Attribute names read from the model object for each encoder name.
# `_instantiate` passes the extracted values positionally to
# `utils.instantiate`, after the dataset-derived values.
model_attrs = {
    "embedding": ["d_model"],
    "linear": ["d_model"],
    "position": ["d_model"],
    "class": ["d_model"],
    "time": ["d_model"],
    "onehot": ["d_model"],
    "conv1d": ["d_model"],
    "patch2d": ["d_model"],
    "timestamp_embedding": ["d_model"],
    "layer": ["d_model"],
    "timeseries_synthetics": ["d_model"],
    "positional_linear": ["d_model"],
    "calendar_positional_linear": ["d_model"],
    "absolute_time": ["d_model"],
    "set_encoder": ["d_model"],
    "stacked_encoder": ["d_model"]
}


def _instantiate(encoder, dataset=None, model=None):
    """Instantiate a single encoder"""
    if encoder is None:
        return None
    if isinstance(encoder, str):
        name = encoder
    else:
        name = encoder["_name_"]

    # Extract dataset/model arguments from attribute names
    dataset_args = utils.config.extract_attrs_from_obj(
        dataset, *dataset_attrs.get(name, [])
    )
    model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, []))

    # Instantiate encoder
    
    obj = utils.instantiate(registry, encoder, *dataset_args, *model_args)
    return obj


def instantiate(encoder, dataset=None, model=None):
    """Build an encoder pipeline wrapped in ``U.PassthroughSequential``.

    ``encoder`` may be a single spec or a list of specs (normalized with
    ``utils.to_list``). Each spec is built by ``_instantiate`` with the same
    ``dataset`` and ``model``, and the results are chained in order.
    """
    specs = utils.to_list(encoder)
    stages = [_instantiate(spec, dataset=dataset, model=model) for spec in specs]
    return U.PassthroughSequential(*stages)
