import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.functional as F
from einops import rearrange
import opt_einsum as oe
import math

# Switch between the two einsum back-ends used throughout this module.
optimized = True

# `contract` is the single entry point the layers below call for tensor
# contractions: opt_einsum's `contract` when `optimized` is set, otherwise
# plain `torch.einsum`.
contract = oe.contract if optimized else torch.einsum

from src.models.nn import LinearActivation, Activation, DropoutNd
from src.models.sequence.block_fft import BlockFFT
from src.models.sequence.long_conv_kernel import LongConvKernel

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.utils.train import OptimModule


def rolling_cumulative_sum(e, window_size=30):
    """
    Rolling-window sum along the last (time) dimension.

    e is shape (B, M, T):
       B = batch dimension
       M = cross-section dimension
       T = time dimension
    (only the last dimension is interpreted; leading dimensions are free).

    Args:
        e: tensor whose last dimension is time.
        window_size: number of trailing time steps to sum over; must be >= 1.

    Returns:
        A new tensor e_cum of the same shape, where
        e_cum[..., t] = sum of e[..., tau] for tau from (t - window_size + 1)
        up to t, with partial sums when t < window_size. Sequences shorter
        than the window are handled (every entry is then a plain cumsum).

    Raises:
        ValueError: if window_size < 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    n_steps = e.shape[-1]

    # Full cumsum along the time dimension; shape is unchanged.
    e_cumsum = e.cumsum(dim=-1)

    # Never shift by more than the sequence length: a plain
    # `e_cumsum[..., :-window_size]` followed by a `window_size` pad would
    # produce a tensor longer than T whenever T < window_size.
    shift = min(window_size, n_steps)

    # Cumsum delayed by `shift` steps, with zeros on the left so that the
    # first `shift` positions keep their partial sums.
    left_zeros = e_cumsum.new_zeros(*e_cumsum.shape[:-1], shift)
    e_cumsum_shifted = torch.cat(
        [left_zeros, e_cumsum[..., :n_steps - shift]], dim=-1
    )

    # Rolling window sum: cumsum up to t minus cumsum up to t - window_size.
    return e_cumsum - e_cumsum_shifted



class AttentionFactorModel(OptimModule):
    """
    Cross-sectional factor model whose factor weights are attention weights
    computed from per-asset covariates.

    The mode is fixed at construction time:
      * normal (n_factors > 0 and debug_market_factor=False): multi-head
        attention factors, described below.
      * debug_market_factor=True: exactly one factor,
        F_t = mean(Y_t across assets), with a learnable beta of shape (M, 1).
      * n_factors == 0: no factor is removed; the residual equals the input.

    Steps for each time step t in normal mode:
      1) score_{k,t} = X_t . Q_k          (one query vector Q_k per factor)
      2) alpha_{k,t} = softmax(10 * score_{k,t}) across the M assets
      3) alpha_t in R^{M x K}; the lagged weights alpha_{t-1} are used below
         (the first time step uses a constant 1e-3 instead)
      4) F_t = alpha_{t-1}^T Y_t in R^K
      5) beta*_t = alpha_{t-1} (alpha_{t-1}^T alpha_{t-1} + 1e-3 I)^{-1}
      6) e_t = Y_t - beta*_t F_t in R^M

    Every mode returns the same 5-tuple from `forward`.
    """
    def __init__(
        self,
        n_assets=600,    # M
        d_x=50,          # dimension of per-asset covariates
        n_factors=5,     # K: number of attention heads/factors
        d_attn=50,       # internal dimension for queries/keys
        debug_market_factor=False,  # use a single equal-weighted market factor
        use_factor_portfolio=True,
    ):
        """
        Args:
            n_assets: number of assets M expected on the asset axis.
            d_x: number of covariates per asset (length of each query).
            n_factors: number of factors K; 0 disables factor removal.
            d_attn: stored on the module but not used by the computation.
            debug_market_factor: if True, use the single market-factor path.
            use_factor_portfolio: if True, create the learnable factor
                portfolio weights and their L1 scale.
        """
        super().__init__()
        self.n_assets  = n_assets
        self.d_x       = d_x
        self.n_factors = n_factors
        self.d_attn    = d_attn
        no_factors = n_factors == 0
        self.debug_market_factor = debug_market_factor
        self.no_factors = no_factors

        self.nn_beta = False

        if not debug_market_factor and not no_factors:
            # Normal multi-factor approach: one query row per factor.
            self.Q = nn.Parameter(torch.randn(n_factors, d_x) * 0.1)
            # Static beta of shape (M, K). forward() returns the time-varying
            # beta* instead, so this parameter is not read in the normal path.
            self.beta = nn.Parameter(torch.randn(n_assets, n_factors) * 0.1)
            self.beta.data[:, 0] = 1.0
            # A zero query gives equal scores, i.e. uniform softmax weights,
            # so the first factor starts as an equal-weighted factor.
            self.Q.data[0, :] = 0.0
            if n_factors > 1:
                self._orthogonalize_queries()

            # Per-parameter optimizer settings; presumably picked up by the
            # training harness through the `_optim` attribute -- TODO confirm.
            self.Q._optim = {'lr': 0.03, 'weight_decay': 0.05}
            self.register_parameter("Q", self.Q)
            self.beta._optim = {'lr': 0.03, 'weight_decay': 0.05}
            self.register_parameter("beta", self.beta)
        else:
            # Market-factor debug mode or no-factor mode: there is no query
            # matrix, and beta is a single column of ones of shape (M, 1).
            self.n_factors = 1
            self.beta = nn.Parameter(torch.ones(n_assets, 1))

        # Factor portfolio weights. Sized with the effective number of
        # factors so that n_factors == 0 does not create an empty tensor
        # (indexing element 0 of it would raise).
        self.use_factor_portfolio_weight = use_factor_portfolio
        if self.use_factor_portfolio_weight:
            self.factor_portfolio_weight = nn.Parameter(torch.zeros(self.n_factors))
            self.factor_portfolio_weight.data[0] = 1.0
            self.l1_norm_factor_portfolio_weight = nn.Parameter(torch.zeros(1))

    def _orthogonalize_queries(self, scale=0.1):
        """Gram-Schmidt-orthogonalise rows 1..K-1 of Q in place.

        Row 0 is left untouched (it is pinned to zero by the caller). Rows
        are processed in blocks of d_x rows, because at most d_x vectors of
        length d_x can be mutually orthogonal; each block restarts the
        procedure. Every processed row ends up with norm `scale`.
        """
        Q = self.Q.data
        block_size = self.d_x
        for i in range(1, Q.shape[0]):
            v = Q[i]  # view into Q: the in-place updates below modify Q

            # Rows block_start .. i-1 are the already-processed rows of the
            # block that row i belongs to.
            block_start = max(1, (i // block_size) * block_size)
            for j in range(block_start, i):
                u = Q[j]
                u_norm_sq = u.dot(u)
                if u_norm_sq > 1e-8:
                    v -= u * (v @ u) / u_norm_sq

            # Normalise, then rescale to the initialisation magnitude.
            norm_v = v.norm(p=2)
            if norm_v > 1e-12:
                v /= norm_v
            v *= scale

    def _l1_norm_value(self, like):
        """Return the detached L1 scale, or a zero tensor if it is disabled."""
        if self.use_factor_portfolio_weight:
            return self.l1_norm_factor_portfolio_weight.data
        return torch.zeros(1, device=like.device, dtype=like.dtype)

    def _factor_portfolio(self, alpha_tm1, like):
        """Combine factor weights (B, T, M, K) into one portfolio (B, T, M).

        The K factor weights are mixed with `factor_portfolio_weight`,
        L1-normalised across assets and scaled by the learnable L1 norm.
        Returns zeros when the factor portfolio is disabled.
        """
        if not self.use_factor_portfolio_weight:
            return like.new_zeros(alpha_tm1.shape[:-1])
        portfolio = torch.einsum('btnk,k->btn', alpha_tm1, self.factor_portfolio_weight)
        # Guard the normalisation against an all-zero cross-section.
        l1 = portfolio.abs().sum(dim=-1, keepdim=True).clamp_min(1e-12)
        return portfolio / l1 * self.l1_norm_factor_portfolio_weight

    def forward(self, u, y):
        """
        Args:
            u: shape (B, H, M, T), per-asset covariates (H must equal d_x)
            y: shape (B, M, T), the series to be explained
        Returns:
            e:       (B, M, T)  residuals
            alpha:   (B, T, M, K), or (B, T, M, 1) in debug / no-factor mode
            beta:    (B, T, M, K) time-varying beta* in normal mode,
                     (M, 1) in debug / no-factor mode
            factor_portfolio: (B, T, M); zeros when disabled or no factors
            l1_norm: detached tensor of shape (1,) with the portfolio L1 scale
        """
        B, H, M, T = u.shape
        assert M == self.n_assets, f"Mismatch in #assets: {M} vs {self.n_assets}"
        assert H == self.d_x,      f"Mismatch in #features: {H} vs {self.d_x}"

        # ==============
        # No-factor path: nothing is removed from y
        # ==============
        if self.no_factors:
            alpha = torch.zeros((B, T, M, 1), device=y.device)
            beta = torch.zeros((M, 1), device=y.device)
            e = y.contiguous()
            factor_portfolio = y.new_zeros((B, T, M))
            return e, alpha, beta, factor_portfolio, self._l1_norm_value(y)

        # ==============
        # Debug path: single "market factor"
        # ==============
        if self.debug_market_factor:
            # Uniform weighting 1/M over the assets at every time step.
            alpha = torch.ones((B, T, M, 1), device=y.device) / M
            # Factor F_t = average of Y_t across the M assets => (B, T, 1)
            F_t = y.mean(dim=1, keepdim=False).unsqueeze(-1)

            # Beta is (M, 1), so beta F_t is (B, T, M)
            BF = torch.einsum("mk, btk -> btm", self.beta, F_t)
            e = (y.permute(0, 2, 1) - BF).permute(0, 2, 1).contiguous()  # (B, M, T)

            factor_portfolio = self._factor_portfolio(alpha, y)
            return e, alpha, self.beta, factor_portfolio, self._l1_norm_value(y)

        # ==============
        # Normal multi-factor path
        # ==============
        X = u.permute(0, 3, 2, 1)  # => (B, T, M, H)
        Y = y.permute(0, 2, 1)     # => (B, T, M)

        # Scores of every asset against every factor query, sharpened by a
        # fixed temperature and normalised across the asset axis.
        scale = 10
        attn_score = torch.einsum('btmd,kd->btmk', X, self.Q)
        alpha = F.softmax(attn_score * scale, dim=2)  # => (B, T, M, K)

        # alpha_{t-1}: shift along time; the first step has no predecessor
        # and is filled with a small constant.
        eps = 1e-3
        alpha_tm1 = torch.roll(alpha, shifts=1, dims=1)  # (B, T, M, K)
        alpha_tm1[:, 0, :, :] = eps

        # Gram matrix G = alpha^T alpha (+ ridge for stability) => (B, T, K, K)
        G = torch.einsum('btmk,btmj->btkj', alpha_tm1, alpha_tm1)
        I = torch.eye(self.n_factors, device=G.device, dtype=G.dtype)
        G += eps * I
        Ginv = torch.linalg.inv(G)

        # beta* = alpha G^{-1} => (B, T, M, K)
        beta_star = torch.einsum('btmk,btkj->btmj', alpha_tm1, Ginv)

        # F_t = alpha_{t-1}^T Y_t => (B, T, K)
        factor = torch.einsum('btmk,btm->btk', alpha_tm1, Y)

        # Fitted values beta* F_t => (B, T, M); residuals back to (B, M, T)
        BF = torch.einsum('btmk,btk->btm', beta_star, factor)
        e = (Y - BF).permute(0, 2, 1).contiguous()

        factor_portfolio = self._factor_portfolio(alpha_tm1, y)

        return e, alpha, beta_star, factor_portfolio, self._l1_norm_value(y)


class AttentionFactors(nn.Module):
    """
    Long-convolution sequence layer with optional pre-processing stages.

    In order, `forward` applies: an optional graph network (`use_gnn`), an
    optional set encoder across units (`use_set_mixing`), an optional
    attention factor model that replaces the input by factor residuals
    (`use_attention_factors`), the FFT long convolution with a skip
    connection, and a position-wise output transform. With attention factors
    enabled, the output is finally projected across the asset dimension.
    """
    def __init__(
            self,
            d_model,
            l_max=1024,
            channels=1,
            bidirectional=False,
            # Arguments for position-wise feedforward components
            activation='gelu', # activation between conv and FF
            postact='glu', # activation after FF
            initializer=None, # initializer on FF
            weight_norm=False, # weight normalization on FF
            dropout=0.0, tie_dropout=False,
            transposed=True, # axis ordering (B, L, D) or (B, D, L)
            verbose=False,
            block_fft_conv=False, # replace the FFT conv with Monarch blocks
            block_fft_conv_args=None,
            use_gnn=False,
            use_small_gnn=False,
            use_layer_norm_gnn=True,
            use_gcn_true=True,
            use_sequence_layer_norm=True,
            gcn_depth=3,
            i_layer=0,
            nr_layers_with_gnn=6,
            use_set_mixing=False,
            nr_layers_with_set=6,
            set_mixing_architecture="MHA",
            set_mixing_dropout=0.0,
            set_debug=False,
            use_layer_norm_set=False,
            set_feature_embedding_dim=None,
            set_chunk_size=3,
            set_expand=2,
            set_projection=False,
            set_common_pool_embedding_dim=2,
            set_n_attn_summary_statistics=True,
            set_nr_attn_heads=4,
            set_var_layer_norm=False,
            set_v_dim=5,
            buildA_true=True,
            kernel_len=None,
            use_attention_factors=True,
            n_factors = 5,
            n_assets = 550,
            use_factor_portfolio=True,
            # SSM Kernel arguments
            **kernel_args,
        ):
        """
        d_state: the dimension of the state, also denoted by N
        l_max: the maximum kernel length, also denoted by L
        channels: can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models
        bidirectional: if True, convolution kernel will be two-sided

        Position-wise feedforward components:
        --------------------
        activation: activation in between SS and FF
        postact: activation after FF ('id' for no activation, None to remove FF layer)
        initializer: initializer on FF
        weight_norm: weight normalization on FF
        dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d

        Other arguments:
        --------------------
        transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension]
        block_fft_conv_args: keyword arguments for the two BlockFFT modules (None means no arguments)
        kernel_len: length of the convolution kernel; defaults to l_max
        i_layer, nr_layers_with_gnn, nr_layers_with_set: the GNN / set encoder is only built and applied while i_layer <= the respective limit
        use_attention_factors, n_factors, n_assets, use_factor_portfolio: configuration of the AttentionFactorModel applied before the convolution
        kernel_args: forwarded unchanged to LongConvKernel
        """

        super().__init__()
        if verbose:
            import src.utils.train
            log = src.utils.train.get_logger(__name__)
            log.info(f"Constructing Long Conv (H, L) = ({d_model}, {l_max})")

        self.d_model = d_model
        self.H = d_model
        self.L = l_max
        self.bidirectional = bidirectional
        self.channels = channels
        self.transposed = transposed
        self.block_fft_conv = block_fft_conv
        # A fresh dict per instance instead of a shared mutable default.
        self.block_fft_conv_args = {} if block_fft_conv_args is None else block_fft_conv_args
        # Read by forward() in the block-FFT branch: whether the inverse
        # transform also goes through the BlockFFT module. Off by default;
        # can be switched on by assigning the attribute after construction.
        self.learn_ifft = False
        self.kernel_len = kernel_len if kernel_len is not None else self.L

        # Per-channel skip-connection weights.
        self.D = nn.Parameter(torch.randn(channels, self.H))

        # A bidirectional kernel stores the forward and backward halves as
        # separate channels (self.channels keeps the un-doubled count).
        if self.bidirectional:
            channels *= 2

        # SSM Kernel
        self.kernel = LongConvKernel(self.H, L=self.kernel_len, channels=channels, verbose=verbose, **kernel_args)

        if self.block_fft_conv:
            self.block_fft_u = BlockFFT(**self.block_fft_conv_args)
            self.block_fft_k = BlockFFT(**self.block_fft_conv_args)

        # Pointwise
        self.activation = Activation(activation)
        # dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout # Broken in torch==1.11
        dropout_fn = DropoutNd if tie_dropout else nn.Dropout
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()

        # position-wise output transform to mix features
        if postact is None:
            self.output_linear = nn.Identity()
        else:
            self.output_linear = LinearActivation(
                self.d_model * self.channels,
                self.d_model,
                transposed=self.transposed,
                initializer=initializer,
                activation=postact,
                activate=True,
                weight_norm=weight_norm,
            )

        # GNN mixing in the batch dimension
        self.use_sequence_layer_norm = use_sequence_layer_norm # True, if the layer norm is applied to the sequence dimension.
        self.use_gcn_true = use_gcn_true
        self.use_gnn = use_gnn
        self.small_gnn = use_small_gnn
        self.use_layer_norm_gnn = use_layer_norm_gnn
        self.gcn_depth = gcn_depth
        self.i_layer = i_layer
        self.nr_layers_with_gnn = nr_layers_with_gnn
        if self.use_gnn and self.i_layer <= self.nr_layers_with_gnn:

            from src.models.sequence.gnn import gtnet
            nr_timeseries = 10
            sequence_length = self.L -1
            feature_dim = self.d_model
            # Create a gtnet model.
            # NOTE: marked "Not working" by the original author; the node
            # count and the CUDA device are hard-coded here.
            self.gtnet_model = gtnet(
                gcn_true=self.use_gcn_true,
                buildA_true=buildA_true,
                gcn_depth=self.gcn_depth,
                num_nodes=nr_timeseries,
                device='cuda:0',
                predefined_A=None,
                static_feat=None,
                dropout=0.3,
                subgraph_size=nr_timeseries, #smaller or equal to num_nodes
                node_dim=40,
                dilation_exponential=1,
                conv_channels=32,
                residual_channels=32,
                seq_length=sequence_length,
                in_dim=feature_dim,
                propalpha=0.05,
                tanhalpha=3,
                layer_norm_affline=True,
                use_layer_norm=self.use_layer_norm_gnn,
                use_sequence_layer_norm=self.use_sequence_layer_norm
            )
        self.use_attention_factors = use_attention_factors

        # Set mixing in the pool dimension:
        self.use_set_mixing = use_set_mixing
        self.nr_layers_with_set = nr_layers_with_set
        self.use_layer_norm_set = use_layer_norm_set
        self.set_mixing_architecture = set_mixing_architecture
        self.set_mixing_dropout = set_mixing_dropout
        self.set_embedding_dim = set_feature_embedding_dim
        self.set_chunk_size = set_chunk_size
        self.expand = set_expand
        self.set_projection = set_projection
        self.common_pool_embedding_dim = set_common_pool_embedding_dim
        self.set_debug = set_debug
        self.n_attn_summary_statistics = set_n_attn_summary_statistics
        self.set_nr_attn_heads = set_nr_attn_heads
        self.set_var_layer_norm = set_var_layer_norm
        self.set_v_dim = set_v_dim
        if self.set_embedding_dim is None:
            self.set_embedding_dim = self.d_model
        if self.use_set_mixing and self.i_layer <= self.nr_layers_with_set:
            import os
            import sys
            # Make the project root importable when SAFARI_PATH is set; an
            # unset variable must not put None on sys.path.
            safari_path = os.environ.get("SAFARI_PATH", None)
            if safari_path is not None and safari_path not in sys.path:
                sys.path.append(safari_path)

            from src.tasks.encoders import SetEncoder
            self.set_encoder = SetEncoder(
                num_states=self.d_model-2,
                loan_pool_size=10,
                d_model=self.d_model,
                common_pool_embedding_dim=self.common_pool_embedding_dim,
                feature_embedding_dim=self.set_embedding_dim,
                debug=self.set_debug,
                architecture=self.set_mixing_architecture,
                use_layer_norm_set=self.use_layer_norm_set,
                chunk_size=self.set_chunk_size,
                nr_attention_heads=self.set_nr_attn_heads, #4
                n_attn_summary_statistics=self.n_attn_summary_statistics,
                dropout=self.set_mixing_dropout,
                expand=self.expand,
                projection=self.set_projection,
                set_var_layer_norm = self.set_var_layer_norm,
                )
        if self.use_attention_factors:
            # Attention factor model applied to the layer input.
            self.use_factor_portfolio = use_factor_portfolio
            self.n_factors = n_factors
            self.n_assets = n_assets
            self.attention_factors = AttentionFactorModel(
                n_factors=self.n_factors,
                d_x=self.d_model,
                d_attn=self.d_model,
                n_assets=self.n_assets,
                use_factor_portfolio=self.use_factor_portfolio,
                )

    def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs): # absorbs return_output and transformer src mask
        """
        u: (B H L) if self.transposed else (B L H)
        state: (H N) never needed, remnant from state spaces repo
        lengths: optional int or 1-D tensor of valid lengths (suffix padding)
        kwargs: "nr_units" is required when set mixing or attention factors
            are active; "y" of shape (B, nr_units, nr_timesteps) is required
            when attention factors are active.

        NOTE(review): the attention-factor branch treats u as
        (B*nr_units, nr_timesteps, H), which matches transposed=False --
        confirm against the caller before using it with transposed=True.

        Returns: (y, info) where y has the same shape as u and info is a dict
        with keys "e", "explained_variance", "factor_portfolio" and
        "l1_norm_factor_portfolio_weight"; the values are None when attention
        factors are disabled.
        """
        if self.use_gnn and self.i_layer <= self.nr_layers_with_gnn:
            u = self.gtnet_model(u)

        if self.use_set_mixing and self.i_layer <= self.nr_layers_with_set:
            # (B*nr_units, L, H) -> (B*nr_units, H, L)
            u = u.transpose(-1, -2)
            # -> (B, nr_units, H, L)
            u = u.reshape(-1, kwargs["nr_units"], u.shape[1], u.shape[2])
            # -> (B, H, nr_units, L)
            u = u.transpose(1, 2)
            u, _ = self.set_encoder(u)

        # Diagnostics are only produced by the attention-factor branch;
        # define them up front so the return statement is always valid.
        e = None
        explained_variance = None
        factor_portfolio = None
        l1_norm_factor_portfolio_weight = None
        alpha = None
        beta = None

        if self.use_attention_factors:
            # The original u has shape (B*nr_units, nr_timesteps, H)
            y_orig = kwargs["y"] # (Batch_size, nr_units, nr_timesteps)

            u = u.transpose(-1, -2)
            u = u.reshape(-1, kwargs["nr_units"], u.shape[1], u.shape[2])
            # (B, H, nr_units, nr_timesteps)
            u = u.transpose(1, 2)

            e, alpha, beta, factor_portfolio, l1_norm_factor_portfolio_weight = self.attention_factors(u, y_orig)

            # Share of each unit's variance over time removed by the factors,
            # averaged over units and batch.
            per_stock_explained_variance = 1 - torch.var(e, dim=-1) / torch.var(y_orig, dim=-1)
            explained_variance = per_stock_explained_variance.mean()

            # Residuals (B, nr_units, nr_timesteps) -> (B*nr_units, nr_timesteps)
            B = e.shape[0]
            e = e.reshape(B * kwargs["nr_units"], -1)
            # The residual becomes the layer input, repeated over the H
            # feature channels: (B*nr_units, nr_timesteps, H).
            u = e.unsqueeze(2).repeat(1, 1, u.shape[1])

        if not self.transposed: u = u.transpose(-1, -2)
        L = u.size(-1)

        # Mask out padding tokens
        # TODO handle option for mask - instead of lengths, which assumes suffix padding
        if isinstance(lengths, int):
            if lengths != L:
                # Wrap in a list: a 0-d tensor would fail the ndim check below.
                lengths = torch.tensor([lengths], dtype=torch.long, device=u.device)
            else:
                lengths = None
        if lengths is not None:
            assert isinstance(lengths, torch.Tensor) and lengths.ndim == 1 and lengths.size(0) in [1, u.size(0)]
            mask = torch.where(torch.arange(L, device=lengths.device) < lengths[:, None, None], 1., 0.)
            u = u * mask

        # Compute SS Kernel; its length is fixed by kernel_len.
        L_kernel = self.kernel_len
        k, _ =  self.kernel(L=L_kernel, rate=rate, state=state) # (C H L) (B C H L)

        # Convolution
        if self.bidirectional:
            k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2)
            k = F.pad(k0, (0, L)) \
                    + F.pad(k1.flip(-1), (L, 0))

        if self.block_fft_conv:
            k_f = self.block_fft_k(k.to(torch.complex64), N=L_kernel+L) # (C H L)
            u_f = self.block_fft_u(u.to(torch.complex64), N=L_kernel+L) # (B H L)
            y_f = contract('bhl,chl->bchl', u_f, k_f)
            if self.learn_ifft:
                y = self.block_fft_u(y_f, N=L_kernel+L,forward=False).real[..., :L]
            else:
                y = torch.fft.ifft(y_f, n=L_kernel+L, dim=-1).real[..., :L] # (B C H L)
        else:
            k_f = torch.fft.rfft(k, n=L_kernel+L) # (C H L)
            u_f = torch.fft.rfft(u, n=L_kernel+L) # (B H L)
            y_f = contract('bhl,chl->bchl', u_f, k_f)
            y = torch.fft.irfft(y_f, n=L_kernel+L)[..., :L] # (B C H L)

        # Compute skip connection
        y = y + contract('bhl,ch->bchl', u, self.D)

        # Reshape to flatten channels
        y = rearrange(y, '... c h l -> ... (c h) l')

        if not self.transposed: y = y.transpose(-1, -2)
        y = self.activation(y)
        y = self.dropout(y)
        y = self.output_linear(y)
        # Assert y is not complex
        assert not torch.is_complex(y), f"y became complex: dtype={y.dtype}"

        if self.use_attention_factors:
            # y is treated as (B*M, T, H). Apply (I - alpha_{t-1} beta^T)
            # across the M dimension at every time step.

            # Lagged attention weights; the first step has no predecessor.
            alpha_tm1 = alpha.clone()
            alpha_tm1[:, 1:, :, :] = alpha[:, :-1, :, :]
            alpha_tm1[:, 0, :, :] = 0.0

            M_ = kwargs["nr_units"]       # number of assets
            B_ = y.shape[0]//M_
            T = y.shape[1]
            H = y.shape[2]

            # 1) Un-flatten the first dimension and put time next to batch
            #    for alignment with alpha => (B, T, M, H)
            y_reshaped = y.reshape(B_, M_, T, H).permute(0, 2, 1, 3)

            # 2) r = beta^T y => (B, T, K, H). beta is (B, T, M, K) when it
            #    varies over time and (M, K) when it is static.
            if beta.dim() == 2:
                r = torch.einsum('mk,btmh->btkh', beta, y_reshaped)
            else:
                r = torch.einsum('btmk,btmh->btkh', beta, y_reshaped)

            # 3) Multiply by alpha across K => (B, T, M, H)
            ar = torch.einsum('btmk, btkh->btmh', alpha_tm1, r)

            # 4) Subtract, permute back to (B, M, T, H) and flatten to
            #    (B*M, T, H) so the shape matches the incoming y
            y_asset = (y_reshaped - ar).permute(0, 2, 1, 3).contiguous()
            y = y_asset.reshape(B_ * M_, T, H)

        return y, {"e": e, "explained_variance": explained_variance, "factor_portfolio": factor_portfolio, "l1_norm_factor_portfolio_weight": l1_norm_factor_portfolio_weight}

    @property
    def d_state(self):
        """State dimension reported to the surrounding model (equals H)."""
        return self.H

    @property
    def d_output(self):
        """Feature dimension of the layer output (equals d_model)."""
        return self.d_model