###########################################
# This file contains the following:
# 1. Linear Transformer Model
# 2. Function for clipping gradient
# 3. Function for generating random data
#
# The notation for linear attention follows
# the paper at https://arxiv.org/pdf/2306.00297.pdf
###########################################

import math
import torch
from torch import nn
import numpy as np
import opt_einsum as oe
import torch.nn.functional as F


# Module-level default device: CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Definition of a single linear attention unit for linear-regression data
# P is the value matrix
# Q is the product of key,query matrices
# the dimensions of the input are
# B: batch-size of prompts
# N: context length (excluding query)
# d: covariate dimension
# P,Q are d x d matrices
# Z is a B x (N+1) x (d+1) tensor
# Output is also B x (N+1) x (d+1)

# For linear attention, activation = None
# For standard attention, activation(x) = torch.nn.functional.softmax(x, dim = 2)
# For ReLU attention, activation(x) = torch.nn.functional.relu(x)

class FeatureMap(nn.Module):
    """
    Base class for query/key feature maps; the base map is the identity.

    Args:
        input_dim: size of the feature dimension of the inputs.
        temp: optional temperature, stored as ``self.temp`` (1. when None).
            Not used by this base class.
        head_dim_idx: index of the feature dimension (stored for subclasses).
        eps: small constant stored for subclasses; unused here.
        **kwargs: accepted and ignored so subclasses can share the signature.
    """
    def __init__(self, 
                 input_dim: int,                 
                 temp: int = None,
                 head_dim_idx: int = -1, 
                 eps: float = 1e-12, 
                 **kwargs: any):
        super().__init__()
        # Resolve the optional temperature before storing the configuration.
        resolved_temp = 1. if temp is None else temp
        self.input_dim, self.head_dim_idx = input_dim, head_dim_idx
        self.temp, self.eps = resolved_temp, eps

    def forward(self, x: torch.Tensor):
        """
        Identity map: hand back ``x`` itself, untouched.

        Inputs are assumed to be laid out as (batch_size, n_heads, seq_len, head_dim),
        although the identity accepts any shape.
        """
        return x

    def expanded_size(self):
        """Return the output feature size; the identity keeps ``input_dim``."""
        return self.input_dim


class TaylorExp(FeatureMap):
    """
    Feature map for the 2nd-order Taylor approximation of exp(q^T k / sqrt(d)).

    The features satisfy f(q)^T f(k) = 1 + q^T k / sqrt(d) + (q^T k)^2 / (2 d),
    with d = input_dim.
    """
    def __init__(self, input_dim: int, **kwargs: any):
        super().__init__(input_dim, **kwargs)
        self.r2  = math.sqrt(2)
        self.rd  = math.sqrt(self.input_dim)
        self.rrd = math.sqrt(self.rd)
        # Strictly-lower-triangular (row, col) index pairs; forward_mem_save uses
        # them to keep each off-diagonal product x_i x_j exactly once.
        self.tril_indices = torch.tril_indices(self.input_dim, self.input_dim, -1)
        
    # Running these in parallel
    def forward(self, x: torch.Tensor):
        """
        Return [1, x / d**0.25, vec(x x^T) / sqrt(2 d)] concatenated along
        ``self.head_dim_idx`` (1 + d + d**2 features when that is the last axis).
        """
        # Get 2nd-order terms (rearrange(x * x), '... m n -> ... (m n)')
        x2 = (x.unsqueeze(-1) * x.unsqueeze(-2)).flatten(start_dim=-2) / self.r2
        # ones_like keeps the constant feature in x's dtype and on x's device
        # (torch.ones(...) would always be float32).
        ones = torch.ones_like(x[..., :1])
        return torch.cat([ones, x / self.rrd, x2 / self.rd], dim=self.head_dim_idx)
        
    def forward_mem_save(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute f(x) s.t. f(x)^T f(x') = 1 + x^Tx' / sqrt(d) + (x^Tx')^2 / (2d),
        the same kernel as ``forward``, using only the d diagonal and d(d-1)/2
        off-diagonal quadratic terms instead of all d**2 products.
        -> Assumes the feature dimension is last, e.g. (batch_size, n_heads, seq_len, head_dim)
        """
        # Slow but memory-saving way to compute 2nd-order terms; how do w/o outer-product first?
        x2  = oe.contract('...m,...n->...mn', x, x) / self.rd
        # Diagonal terms x_i^2 / sqrt(2d).
        x2d = torch.diagonal(x2, dim1=-2, dim2=-1) / self.r2
        # Off-diagonal terms x_i x_j / sqrt(d) for i > j: keeping one of each
        # symmetric pair absorbs the factor 2, so no division by sqrt(2) here.
        tril = self.tril_indices.to(x.device)
        x2  = x2[..., tril[0], tril[1]]
        ones = torch.ones_like(x[..., :1])
        return torch.cat([ones, x / self.rrd, x2d, x2], dim=-1)


def attention(P,Q,Z, activation = None):
    """
    Single linear-attention head for linear-regression prompts (arXiv:2306.00297).

    Args:
        P: (d, d) value matrix. It is embedded as the top-left block of a
            (d+1, d+1) matrix whose bottom-right entry is 1.
        Q: (d, d) combined key-query matrix. It is embedded as the top-left block
            of a (d+1, d+1) zero matrix.
        Z: (B, N+1, d+1) prompts; row N is the query token.
        activation: optional callable applied to the (B, N+1, N+1) scores, e.g.
            ``lambda x: torch.nn.functional.softmax(x, dim=2)``; None gives
            plain linear attention.

    Returns:
        (B, N+1, d+1) tensor act(Z Q_full Z^T) @ mask @ (Z P_full^T) / N, where
        mask is the identity with entry [N, N] zeroed, so the query token
        contributes no value.
    """
    N = Z.shape[1] - 1
    d = Z.shape[2] - 1
    # Build the padded matrices and the mask on Z's device and dtype (not a
    # global device) so CPU, CUDA and float64 prompts all work.
    P_full = Z.new_zeros(d + 1, d + 1)
    P_full[:d, :d] = P
    P_full[d, d] = 1
    Q_full = Z.new_zeros(d + 1, d + 1)
    Q_full[:d, :d] = Q
    A = torch.eye(N + 1, device=Z.device, dtype=Z.dtype)
    A[N, N] = 0
    Attn = torch.einsum('BNi, ij, BMj -> BNM', (Z, Q_full, Z))
    if activation is not None:
        Attn = activation(Attn)
    key = torch.einsum('ij, BNj -> BNi', (P_full, Z))
    Output = torch.einsum('BNM,ML, BLi -> BNi', (Attn, A, key))
    return Output / N


def attention_B(P_full,QT_full,D_full,K_bias,V_bias,Z,activation):
    """
    Softmax attention head with a learnable bias key/value slot prepended.

    Args:
        P_full: (d+1, d+1) value projection; value = Z @ P_full^T.
        QT_full: (d+1, d+1) query projection; query = Z @ QT_full^T.
        D_full: (d+1, d+1) key projection; key = Z @ D_full^T.
        K_bias: bias key, expected shape (1, d+1); repeated over the batch and
            prepended to the keys.
        V_bias: bias value, expected shape (1, d+1); prepended to the values.
        Z: (B, N+1, d+1) prompts; row N is the query token.
        activation: unused; softmax is always applied. Kept for interface
            compatibility.

    Returns:
        (Output / N, Attn_logit, Attn, Attn_prob, query, key), where:
        - Attn_logit is query @ key^T without the bias slot, shape (B, N+1, N+1);
        - Attn are the logits including the bias slot, shape (B, N+1, N+2);
        - Attn_prob = softmax(Attn) over the last axis.
    """
    B = Z.shape[0]
    N = Z.shape[1] - 1
    query = torch.einsum('ij, BNj -> BNi', (QT_full, Z))
    key = torch.einsum("ij, BMj -> BMi", (D_full, Z))
    Attn_logit = torch.einsum("BNi, BMi -> BNM", (query, key))

    # Prepend the bias slot: position 0 is the bias, original row r moves to r + 1.
    n_key = torch.cat((K_bias.repeat(B, 1, 1), key), dim=1)
    Attn = torch.einsum("BNi, BMi -> BNM", (query, n_key))
    Attn_prob = torch.nn.functional.softmax(Attn, dim=2)
    value = torch.einsum('ij, BNj -> BNi', (P_full, Z))
    value = torch.cat((V_bias.repeat(B, 1, 1), value), dim=1)

    # Exclude the query token's value. Because of the prepended bias slot the
    # query token is at index N + 1, so that diagonal entry is zeroed.
    A = torch.eye(N + 2, device=Z.device, dtype=Z.dtype)
    A[N + 1, N + 1] = 0
    Output = torch.einsum('BNM,ML, BLi -> BNi', (Attn_prob, A, value))

    return Output / N, Attn_logit, Attn, Attn_prob, query, key


def attention_C(P_full,QT_full,D_full,Gamma,Beta,Z,activation):
    """
    Attention head with LayerNorm (learnable Gamma/Beta) on queries and keys.

    Args:
        P_full, QT_full, D_full: (d+1, d+1) value, query and key projections
            (value = Z @ P_full^T, query = Z @ QT_full^T, key = Z @ D_full^T).
        Gamma, Beta: LayerNorm scale and shift, broadcast against the last axis.
        Z: (B, N+1, d+1) prompts; row N is the query token.
        activation:
            "ln_relu"    -- ReLU(LN(q)) . ReLU(LN(k)), each row divided by its sum;
            "ln_softmax" -- softmax(LN(q) . LN(k)).

    Returns:
        (Output / N, Attn_logit, Attn, Attn_prob, query, key). Attn_logit is the
        raw query @ key^T; query and key are the pre-LayerNorm projections.

    Raises:
        ValueError: for an unsupported ``activation``.
    """
    if activation not in ("ln_relu", "ln_softmax"):
        raise ValueError(f"attention_C: unsupported activation {activation!r}")

    def _layer_norm(x):
        # LayerNorm over the feature axis: biased variance, eps inside the sqrt,
        # then the learnable affine transform.
        mean = x.mean(dim=-1, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
        return (x - mean) / (var + 1e-12).sqrt() * Gamma + Beta

    N = Z.shape[1] - 1
    # Identity with entry [N, N] zeroed: the query token contributes no value.
    A = torch.eye(N + 1, device=Z.device, dtype=Z.dtype)
    A[N, N] = 0
    query = torch.einsum('ij, BNj -> BNi', (QT_full, Z))
    key = torch.einsum("ij, BMj -> BMi", (D_full, Z))
    Attn_logit = torch.einsum("BNi, BMi -> BNM", (query, key))

    ln_query = _layer_norm(query)
    ln_key = _layer_norm(key)

    if activation == "ln_relu":
        Attn = torch.einsum("BNi, BMi -> BNM", (F.relu(ln_query), F.relu(ln_key)))
        # Divide each (non-negative) score row by its own sum so every row of
        # Attn_prob sums to 1; 1e-12 guards all-zero rows.
        Attn_prob = Attn / (Attn.sum(dim=-1, keepdim=True) + 1e-12)
    else:  # "ln_softmax"
        Attn = torch.einsum("BNi, BMi -> BNM", (ln_query, ln_key))
        Attn_prob = torch.nn.functional.softmax(Attn, dim=2)

    value = torch.einsum('ij, BNj -> BNi', (P_full, Z))
    Output = torch.einsum('BNM,ML, BLi -> BNi', (Attn_prob, A, value))

    return Output / N, Attn_logit, Attn, Attn_prob, query, key



def attention_D(P_full,QT_full,D_full,Z,activation):
    """
    Single attention head with a selectable (mostly kernelized) attention rule.

    Args:
        P_full, QT_full, D_full: (d+1, d+1) value, query and key projections
            (value = Z @ P_full^T, query = Z @ QT_full^T, key = Z @ D_full^T).
        Z: (B, N+1, d+1) prompts; row N is the query token.
        activation: attention rule. Rules marked "row-normalized" divide each
            score row by its sum (clamped to >= 1e-12).
            "relu"        -- relu(q) . relu(k), row-normalized
            "relu<c>"     -- (relu(q) + c) . (relu(k) + c) for a non-negative
                             integer offset c (e.g. "relu1", "relu3", "relu10"),
                             row-normalized
            "relu_square" -- relu(q)**2 . relu(k)**2, row-normalized
            "square"      -- q**2 . k**2, row-normalized
            "sep_softmax" -- exp(q) . exp(k), row-normalized
            "elu"         -- (elu(q) + 1) . (elu(k) + 1), row-normalized
            "relu_new"    -- relu(q . k), row-normalized
            "softmax"     -- softmax(q . k); the returned Attn is exp(q . k)
            "n_softmax"   -- softmax of the per-row standardized logits

    Returns:
        (Output / N, Attn_logit, Attn, Attn_prob, query, key), where:
        - Attn_logit is query @ key^T;
        - Attn holds the unnormalized scores;
        - Attn_prob holds the normalized scores;
        - query and key are the raw projections.

    Raises:
        ValueError: for an unsupported ``activation``.
    """
    N = Z.shape[1] - 1
    # Identity with entry [N, N] zeroed: the query token contributes no value.
    A = torch.eye(N + 1, device=Z.device, dtype=Z.dtype)
    A[N, N] = 0
    query = torch.einsum('ij, BNj -> BNi', (QT_full, Z))
    key = torch.einsum("ij, BMj -> BMi", (D_full, Z))
    Attn_logit = torch.einsum("BNi, BMi -> BNM", (query, key))

    def _row_normalize(scores):
        # Divide each row by its sum, clamped away from zero.
        return scores / torch.clamp_min(scores.sum(dim=-1, keepdim=True), 1e-12)

    # Separable kernels: scores = phi(query) @ phi(key)^T.
    feature_maps = {
        "relu": F.relu,
        # Squared-ReLU features (the squared tensors must actually be used).
        "relu_square": lambda x: torch.pow(F.relu(x), 2),
        "square": lambda x: torch.pow(x, 2),
        "sep_softmax": torch.exp,
        "elu": lambda x: F.elu(x) + 1,
    }

    if activation in feature_maps:
        phi = feature_maps[activation]
        Attn = torch.einsum("BNi, BMi -> BNM", (phi(query), phi(key)))
        Attn_prob = _row_normalize(Attn)
    elif (isinstance(activation, str) and activation.startswith("relu")
          and activation[4:].isdecimal()):
        # "relu<c>": ReLU features shifted by a constant offset c.
        offset = int(activation[4:])
        Attn = torch.einsum("BNi, BMi -> BNM", (F.relu(query) + offset, F.relu(key) + offset))
        Attn_prob = _row_normalize(Attn)
    elif activation == "relu_new":
        Attn = F.relu(Attn_logit)
        Attn_prob = _row_normalize(Attn)
    elif activation == "softmax":
        Attn = torch.exp(Attn_logit)
        Attn_prob = torch.nn.functional.softmax(Attn_logit, dim=2)
    elif activation == "n_softmax":
        Attn = Attn_logit
        mean = Attn.mean(dim=-1, keepdim=True)
        std = Attn.std(dim=-1, keepdim=True) + 1e-12  # Add epsilon for numerical stability
        # Softmax outputs are already <= 1, so no extra clipping is needed.
        Attn_prob = torch.nn.functional.softmax((Attn - mean) / std, dim=2)
    else:
        raise ValueError(f"attention_D: unsupported activation {activation!r}")

    value = torch.einsum('ij, BNj -> BNi', (P_full, Z))
    Output = torch.einsum('BNM,ML, BLi -> BNi', (Attn_prob, A, value))

    return Output / N, Attn_logit, Attn, Attn_prob, query, key


def attention_A(P_full,QT_full,D_full,Gamma,Beta,Z,activation):
    """
    Attention head with LayerNorm (learnable Gamma/Beta) on queries and keys.

    Args:
        P_full, QT_full, D_full: (d+1, d+1) value, query and key projections
            (value = Z @ P_full^T, query = Z @ QT_full^T, key = Z @ D_full^T).
        Gamma, Beta: LayerNorm scale and shift, broadcast against the last axis.
        Z: (B, N+1, d+1) prompts; row N is the query token.
        activation:
            "sn_relu"    -- ReLU(LN(q)) . ReLU(LN(k)), each row divided by its sum;
            "sn_softmax" -- softmax(LN(q) . LN(k)).

    Returns:
        (Output / N, Attn_logit, Attn, Attn_prob, query, key). Attn_logit is the
        raw query @ key^T; query and key are the pre-LayerNorm projections.

    Raises:
        ValueError: for an unsupported ``activation``.
    """
    if activation not in ("sn_relu", "sn_softmax"):
        raise ValueError(f"attention_A: unsupported activation {activation!r}")

    def _layer_norm(x):
        # LayerNorm over the feature axis: biased variance, eps inside the sqrt,
        # then the learnable affine transform (computed out of place).
        mean = x.mean(dim=-1, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
        return (x - mean) / (var + 1e-12).sqrt() * Gamma + Beta

    N = Z.shape[1] - 1
    # Identity with entry [N, N] zeroed: the query token contributes no value.
    A = torch.eye(N + 1, device=Z.device, dtype=Z.dtype)
    A[N, N] = 0
    query = torch.einsum('ij, BNj -> BNi', (QT_full, Z))
    key = torch.einsum("ij, BMj -> BMi", (D_full, Z))
    Attn_logit = torch.einsum("BNi, BMi -> BNM", (query, key))

    ln_query = _layer_norm(query)
    ln_key = _layer_norm(key)

    if activation == "sn_relu":
        # ReLU keeps the scores non-negative so the row sums are valid normalizers;
        # each row of Attn_prob then sums to 1 (1e-12 guards all-zero rows).
        Attn = torch.einsum("BNi, BMi -> BNM", (F.relu(ln_query), F.relu(ln_key)))
        Attn_prob = Attn / (Attn.sum(dim=-1, keepdim=True) + 1e-12)
    else:  # "sn_softmax": softmax over the LayerNorm-ed scores
        Attn = torch.einsum("BNi, BMi -> BNM", (ln_query, ln_key))
        Attn_prob = torch.nn.functional.softmax(Attn, dim=2)

    value = torch.einsum('ij, BNj -> BNi', (P_full, Z))
    Output = torch.einsum('BNM,ML, BLi -> BNi', (Attn_prob, A, value))

    return Output / N, Attn_logit, Attn, Attn_prob, query, key


def attention_E(Pi,Qi,Di,Z,activation):
    """
    Attention over already-projected queries, keys and values.

    Args:
        Pi: queries, (B, N+1, e).
        Qi: keys, (B, N+1, e).
        Di: values, (B, N+1, e).
        Z: (B, N+1, d+1) prompts; only its shape is used here (N = Z.shape[1] - 1).
        activation: attention rule. Rules marked "row-normalized" divide each
            score row by its sum (+1e-12).
            "s_relu"        -- relu(q) . relu(k), row-normalized
            "s_sep_softmax" -- exp(q) . exp(k), row-normalized
            "s_elu"         -- (elu(q) + 1) . (elu(k) + 1), row-normalized
            "s_softmax"     -- softmax(q . k)
            "n_softmax"     -- softmax of the per-row standardized logits

    Returns:
        (Output / N, Attn_logit, Attn, Attn_prob, query, key). Attn_logit is
        Pi @ Qi^T. For "s_sep_softmax" and "s_elu" the returned query/key are
        the feature-mapped tensors; for the other rules they are Pi and Qi.
        The row of index N (the query token) is masked out of the values.

    Raises:
        ValueError: for an unsupported ``activation``.
    """
    N = Z.shape[1] - 1
    query = Pi
    key = Qi
    value = Di
    # Identity with entry [N, N] zeroed; built to match the values' device/dtype.
    A = torch.eye(N + 1, device=value.device, dtype=value.dtype)
    A[N, N] = 0
    Attn_logit = torch.einsum("BNi, BMi -> BNM", (query, key))

    def _row_normalize(scores):
        # Each row divided by its own sum so rows sum to 1 (1e-12 guards zeros).
        return scores / (scores.sum(dim=-1, keepdim=True) + 1e-12)

    if activation == "s_relu":
        Attn = torch.einsum("BNi, BMi -> BNM", (F.relu(query), F.relu(key)))
        Attn_prob = _row_normalize(Attn)
    elif activation == "s_sep_softmax":
        query = torch.exp(query)
        key = torch.exp(key)
        Attn = torch.einsum("BNi, BMi -> BNM", (query, key))
        Attn_prob = _row_normalize(Attn)
    elif activation == "s_elu":
        query = F.elu(query) + 1
        key = F.elu(key) + 1
        Attn = torch.einsum("BNi, BMi -> BNM", (query, key))
        Attn_prob = _row_normalize(Attn)
    elif activation == "s_softmax":
        Attn = Attn_logit
        Attn_prob = torch.nn.functional.softmax(Attn, dim=2)
    elif activation == "n_softmax":
        Attn = Attn_logit
        mean = Attn.mean(dim=-1, keepdim=True)
        std = Attn.std(dim=-1, keepdim=True) + 1e-12  # Add epsilon for numerical stability
        # Softmax outputs are already <= 1, so no extra clipping is needed.
        Attn_prob = torch.nn.functional.softmax((Attn - mean) / std, dim=2)
    else:
        raise ValueError(f"attention_E: unsupported activation {activation!r}")

    Output = torch.einsum('BNM,ML, BLi -> BNi', (Attn_prob, A, value))

    return Output/N, Attn_logit, Attn, Attn_prob, query, key

# The Linear Transformer module
# n_layer denotes the number of layers
# n_head denotes the number of heads. In most of our experiments, n_head = 1
# d denotes the dimension of covariates
# var denotes the standard deviation of the normal initialization (it is passed as the std to normal_). It needs to be sufficiently small, but exact value is not important
# allparam: contains all the parameters, has dimension n_layer x n_head x 2 x d x d
# For example
# - P matrix at layer i, head j is allparam[i,j,0,:,:]
# - Q matrix at layer i, head j is allparam[i,j,1,:,:]


class SpectralNormedWeight(nn.Module):
    """SpectralNorm Layer. First sigma uses SVD, then power iteration.

    ``forward()`` returns ``weight / sigma``, where sigma estimates the largest
    singular value of ``weight``. The estimate is initialised exactly with an
    SVD and afterwards refreshed with one power-iteration step per call.

    Buffers:
        u: current estimate of the top right-singular vector.
        spectral_norm: (1,) tensor holding the latest sigma estimate; it is
            updated in training mode only.
    """

    def __init__(
        self,
        weight: torch.Tensor,
    ):
        super().__init__()
        # When ``weight`` is an nn.Parameter (as passed by SNLinear), this
        # assignment also registers it on this module; the tensor is shared.
        self.weight = weight
        with torch.no_grad():
            _, s, vh = torch.linalg.svd(self.weight, full_matrices=False)

        self.register_buffer("u", vh[0])
        self.register_buffer("spectral_norm", s[0] * torch.ones(1))

    def get_sigma(self, u: torch.Tensor, weight: torch.Tensor):
        """Run one power-iteration step and return the estimate v^T W u.

        The singular-vector updates run without gradient tracking, so the result
        is differentiable w.r.t. ``weight`` only. In training mode the refreshed
        right-singular vector is written back to ``self.u``.
        """
        with torch.no_grad():
            v = weight.mv(u)
            v = nn.functional.normalize(v, dim=0)
            u = weight.T.mv(v)
            u = nn.functional.normalize(u, dim=0)
            if self.training:
                self.u.data.copy_(u)

        return torch.einsum("c,cd,d->", v, weight, u)

    def forward(self):
        """Normalize by largest singular value and rescale by learnable."""
        sigma = self.get_sigma(u=self.u, weight=self.weight)
        if self.training:
            self.spectral_norm.data.copy_(sigma)

        return self.weight / sigma


class SNLinear(nn.Linear):
    """Linear layer with a sigma-reparameterized, spectrally normalized weight.

    The weight used in ``forward`` is ``sigma * W / sigma_max(W)``. ``sigma`` is
    a learnable scalar initialised to sigma_max(W) at construction. With
    ``stats_only=True`` the raw ``W`` is used and the spectral norm is only
    tracked.

    Args:
        in_features, out_features, bias: as for ``nn.Linear`` (bias defaults to False).
        init_multiplier: scales the truncated-normal init std (0.02 * init_multiplier).
        stats_only: if True, do not reparameterize; only track statistics.
    """
 
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        init_multiplier: float = 1.0,
        stats_only: bool = False,
    ):
        super().__init__(in_features, out_features, bias=bias)
        self.stats_only = stats_only
        self.init_multiplier = init_multiplier

        self.init_std = 0.02 * init_multiplier
        nn.init.trunc_normal_(self.weight, std=self.init_std)

        # Handle normalization and add a learnable scalar.
        self.spectral_normed_weight = SpectralNormedWeight(self.weight)
        sn_init = self.spectral_normed_weight.spectral_norm

        # Would have set sigma to None if `stats_only` but jit really disliked this
        self.sigma = (
            torch.ones_like(sn_init)
            if self.stats_only
            else nn.Parameter(sn_init.clone(), requires_grad=True)
        )

        # Register a *copy*. Registering ``sn_init`` itself would alias the inner
        # module's ``spectral_norm`` buffer, so every update_effective_spec_norm()
        # would overwrite the tracked spectral norm with ``sigma``.
        self.register_buffer("effective_spectral_norm", sn_init.clone())
        self.update_effective_spec_norm()

    def update_effective_spec_norm(self):
        """Update the buffer corresponding to the spectral norm for tracking."""
        with torch.no_grad():
            s_0 = (
                self.spectral_normed_weight.spectral_norm
                if self.stats_only
                else self.sigma
            )
            self.effective_spectral_norm.data.copy_(s_0)

    def get_weight(self):
        """Return the weight used by ``forward`` and refresh the norm trackers.

        The spectral-norm estimate is always refreshed. The result is the raw
        weight when ``stats_only``; otherwise it is ``sigma * W / sigma_max(W)``.
        """
        normed_weight = self.spectral_normed_weight()
        self.update_effective_spec_norm()
        return self.weight if self.stats_only else normed_weight * self.sigma

    def forward(self, inputs: torch.Tensor):
        """Apply ``F.linear`` with the (possibly reparameterized) weight."""
        weight = self.get_weight()
        return F.linear(inputs, weight, self.bias)


class Transformer_F(nn.Module):
    """
    Residual stack of linear-attention layers: Z <- Z + sum_j attention(P_ij, Q_ij, Z).

    Args:
        n_layer: number of layers.
        n_head: number of heads per layer.
        N: context length (unused; kept for a uniform constructor signature).
        d: covariate dimension; every P and Q is d x d.
        var: standard deviation of the normal initialization of ``allparam``.

    ``allparam`` has shape (n_layer, n_head, 2, d, d). Index 0 on the third
    axis holds P and index 1 holds Q.
    """
    def __init__(self, n_layer, n_head, N, d, var):
        super(Transformer_F, self).__init__()
        weights = torch.zeros(n_layer, n_head, 2, d, d)
        self.register_parameter('allparam', torch.nn.Parameter(weights))
        with torch.no_grad():
            self.allparam.normal_(0, var)
        self.n_layer, self.n_head = n_layer, n_head

    def forward(self, Z):
        """Apply every layer in turn; each adds the sum of its heads to its input."""
        for layer in range(self.n_layer):
            Z_in = Z
            # Running total over heads, starting from 0 and added left to right.
            update = sum(
                attention(self.allparam[layer, head, 0, :, :],
                          self.allparam[layer, head, 1, :, :],
                          Z_in)
                for head in range(self.n_head)
            )
            Z = Z_in + update
        return Z

    #enforces top-left-dxd-block sparsity on p
    def zero_p(self):
        """Set every P matrix (all layers and heads) to zero in place, untracked."""
        with torch.no_grad():
            self.allparam[:self.n_layer, :self.n_head, 0].zero_()


class Transformer_A(Transformer_F):
    """
    Residual attention stack using spectrally normalized (SNLinear) projections
    and ``attention_E``.

    The three projections Pij (queries), Qij (keys) and Dij (values) are single
    SNLinear(d+1, d+1) modules reused by every layer. ``n_head`` is stored but
    only one head is computed per layer.
    """
    def __init__(self, n_layer, n_head, N, d, var, activation):
        # Skip Transformer_F.__init__ on purpose: this model has no ``allparam``.
        super(Transformer_F, self).__init__()
        self.activation = activation
        self.Pij = SNLinear(d+1,d+1, bias=False, init_multiplier=1.0, stats_only=False)
        self.Qij = SNLinear(d+1,d+1, bias=False, init_multiplier=1.0, stats_only=False)
        self.Dij = SNLinear(d+1,d+1, bias=False, init_multiplier=1.0, stats_only=False)
        
        self.n_layer = n_layer
        self.n_head = n_head
    
    def forward(self, Z):
        """
        Apply every layer, Z <- Z + attention_E(Pij(Z), Qij(Z), Dij(Z), Z), and
        collect per-layer diagnostics.

        Returns:
            (Z, attns_logit, attns_prob, q_norm, k_norm, inout_var). The lists
            hold one entry per layer. ``inout_var`` entries are 0-dim tensors
            |var(Z_in) - var(attention output)|.
        """
        attns_logit = []
        attns_prob = []
        q_norm = []
        k_norm = []
        inout_var = []

        for i in range(self.n_layer):
            Zi = Z
            Pi = self.Pij(Zi)
            Qi = self.Qij(Zi)
            Di = self.Dij(Zi)

            # attention_E returns six values; the unnormalized scores (third) are
            # not collected here.
            attn_output, attn_logit, _, attn_prob, query, key = attention_E(Pi,Qi,Di,Zi,self.activation)
            in_out_vardiff = abs(torch.var(Zi) - torch.var(attn_output))
            
            Z = Zi + attn_output
            inout_var.append(in_out_vardiff)
            attns_logit.append(attn_logit)
            attns_prob.append(attn_prob)
            q_norm.append(query.norm().item())
            k_norm.append(key.norm().item())

        return Z, attns_logit, attns_prob, q_norm, k_norm, inout_var
    
    #enforces top-left-dxd-block sparsity on p
    def zero_p(self):
        """Not supported: this model has no ``allparam`` tensor to sparsify."""
        raise NotImplementedError(
            "Transformer_A has no allparam; its value/query/key projections are SNLinear modules"
        )



class Transformer_D(Transformer_F):
    """
    Residual attention stack built on ``attention_D``.

    ``allparam`` has shape (n_layer, n_head, 3, d+1, d+1). On the third axis,
    index 0 is the value matrix P, 1 the query matrix and 2 the key matrix.
    Entries are initialized from a normal distribution with std ``var``.
    ``activation`` selects the attention rule passed to ``attention_D``.
    ``zero_p`` is inherited from Transformer_F.
    """
    def __init__(self, n_layer, n_head, N, d, var, activation):
        # Skip Transformer_F.__init__ on purpose: this model needs three
        # (d+1) x (d+1) matrices per head instead of two d x d ones.
        super(Transformer_F, self).__init__()
        self.register_parameter('allparam', torch.nn.Parameter(torch.zeros(n_layer, n_head, 3, d+1, d+1)))

        with torch.no_grad():
            self.allparam.normal_(0,var)
        
        self.n_layer = n_layer
        self.n_head = n_head
        self.activation = activation

    def forward(self, Z):
        """
        Apply every layer, Z <- Z + sum_j attention_D(P_ij, Q_ij, D_ij, Z), and
        collect per-layer diagnostics.

        Returns:
            Z, attns_logit, attns_logit2, attns_prob, q_norm, k_norm, inout_var,
            inout_mean, q_var, k_var, attn_vars, attn_norms.

            Each list has one entry per layer, taken from the last head of that layer:
            - attns_logit, attns_logit2, attns_prob: raw logits, unnormalized
              scores and normalized scores returned by attention_D;
            - q_norm, k_norm, q_var, k_var: norm and variance of the query/key
              projections;
            - inout_var, inout_mean: |var(Z_in) - var(head output)| and
              |mean(Z_in) - mean(head output)|;
            - attn_vars: variance of the first logit row of the first prompt;
            - attn_norms: batch mean of the 2-norm of attns_prob over dims (1, 2).
        """
        attns_logit = []
        attns_logit2 = []
        attns_prob = []
        q_norm = []
        k_norm = []
        inout_var = []
        inout_mean = []
        q_var = []
        k_var = []
        attn_vars = []
        attn_norms = []

        for i in range(self.n_layer):
            Zi = Z
            residues = 0
           
            # the forward map of each layer is given by F(Z) = Z + attention(Z)
            for j in range(self.n_head):
                Pij = self.allparam[i,j,0,:,:]
                Qij = self.allparam[i,j,1,:,:]
                Dij = self.allparam[i,j,2,:,:]
                attn_output, attn_logit, attn_logit2, attn_prob, query, key = attention_D(Pij,Qij,Dij,Zi,self.activation)
                in_out_vardiff = abs(torch.var(Zi) - torch.var(attn_output))
                in_out_meandiff = abs(torch.mean(Zi) - torch.mean(attn_output))
                residues = residues + attn_output

            Z = Zi + residues
            # Only the first element of the row variances is recorded, so read
            # it directly instead of converting the whole tensor to a list.
            attn_vars.append(torch.var(attn_logit, dim=-1).flatten()[0].item())
            inout_var.append(in_out_vardiff.item())
            inout_mean.append(in_out_meandiff.item())
            attns_logit.append(attn_logit)
            attns_logit2.append(attn_logit2)
            attns_prob.append(attn_prob)
            q_norm.append(query.norm().item())
            k_norm.append(key.norm().item())
            q_var.append(torch.var(query).item())
            k_var.append(torch.var(key).item())
            attn_norms.append(torch.norm(attn_prob, p=2, dim=(1,2)).mean().item())

        return Z, attns_logit, attns_logit2, attns_prob, q_norm, k_norm, inout_var, inout_mean, q_var, k_var, attn_vars, attn_norms


class Transformer_C(nn.Module):
    """
    Residual attention stack built on ``attention_C`` (LayerNorm on queries/keys).

    Parameters:
        allparam: (n_layer, n_head, 3, d+1, d+1). Index 0 on the third axis is
            the value matrix P, 1 the query matrix and 2 the key matrix. Entries
            are initialized from a normal distribution with std ``var``.
        gamma, beta: (n_layer, n_head, 1, d+1) LayerNorm scale (ones) and
            shift (zeros).
    ``activation`` selects the rule passed to ``attention_C``.
    """
    def __init__(self, n_layer, n_head, N, d, var, activation):
        super(Transformer_C, self).__init__()
        self.register_parameter('allparam', torch.nn.Parameter(torch.zeros(n_layer, n_head, 3, d+1, d+1), requires_grad=True))
        self.register_parameter('gamma', torch.nn.Parameter(torch.ones(n_layer, n_head, 1, d+1,), requires_grad=True))
        self.register_parameter('beta', torch.nn.Parameter(torch.zeros(n_layer, n_head, 1, d+1,), requires_grad=True))
       
        with torch.no_grad():
            self.allparam.normal_(0,var)
        self.n_layer = n_layer
        self.n_head = n_head
        self.activation = activation

    def forward(self, Z):
        """
        Apply every layer, Z <- Z + sum_j attention_C(P_ij, Q_ij, D_ij, G_ij, B_ij, Z),
        and collect per-layer diagnostics.

        Returns:
            Z, attns_logit, attns_logit2, attns_prob, q_norm, k_norm, inout_var,
            inout_mean, q_var, k_var, attn_vars, attn_norms.

            Each list has one entry per layer, taken from the last head of that layer:
            - attns_logit, attns_logit2, attns_prob: raw logits, unnormalized
              scores and normalized scores returned by attention_C;
            - q_norm, k_norm, q_var, k_var: norm and variance of the query/key
              projections;
            - inout_var, inout_mean: |var(Z_in) - var(head output)| and
              |mean(Z_in) - mean(head output)|;
            - attn_vars: variance of the first logit row of the first prompt;
            - attn_norms: batch mean of the 2-norm of attns_prob over dims (1, 2).
        """
        attns_logit = []
        attns_logit2 = []
        attns_prob = []
        q_norm = []
        k_norm = []
        inout_var = []
        inout_mean = []
        q_var = []
        k_var = []
        attn_vars = []
        attn_norms = []

        for i in range(self.n_layer):
            Zi = Z
            residues = 0
            
            # the forward map of each layer is given by F(Z) = Z + attention(Z)
            for j in range(self.n_head):
                Pij = self.allparam[i,j,0,:,:]
                Qij = self.allparam[i,j,1,:,:]
                Dij = self.allparam[i,j,2,:,:]

                # For LayerNorm
                Gij = self.gamma[i,j,:,:]
                Bij = self.beta[i,j,:,:]

                attn_output, attn_logit, attn_logit2, attn_prob, query, key = attention_C(Pij,Qij,Dij,Gij,Bij,Zi,self.activation)
                in_out_vardiff = abs(torch.var(Zi) - torch.var(attn_output))
                in_out_meandiff = abs(torch.mean(Zi) - torch.mean(attn_output))
                residues = residues + attn_output

            Z = Zi + residues
            # Only the first element of the row variances is recorded, so read
            # it directly instead of converting the whole tensor to a list.
            attn_vars.append(torch.var(attn_logit, dim=-1).flatten()[0].item())
            inout_var.append(in_out_vardiff.item())
            inout_mean.append(in_out_meandiff.item())
            attns_logit.append(attn_logit)
            attns_logit2.append(attn_logit2)
            attns_prob.append(attn_prob)
            q_norm.append(query.norm().item())
            k_norm.append(key.norm().item())
            q_var.append(torch.var(query).item())
            k_var.append(torch.var(key).item())
            attn_norms.append(torch.norm(attn_prob, p=2, dim=(1,2)).mean().item())

        return Z, attns_logit, attns_logit2, attns_prob, q_norm, k_norm, inout_var, inout_mean, q_var, k_var, attn_vars, attn_norms

    #enforces top-left-dxd-block sparsity on p
    def zero_p(self):
        """Zero every P matrix (index 0 of the third axis) in place, untracked."""
        for i in range(self.n_layer):
            for j in range(self.n_head):
                with torch.no_grad():
                    self.allparam[i,j,0,:,:].zero_()


class Transformer_B(Transformer_F):
    """
    Multi-head residual attention model whose heads are computed by
    `attention_B`, each head receiving three matrices and two bias rows.

    Registered parameters:
        allparam  -- (n_layer, n_head, 3, d+1, d+1); slot 0 = P, 1 = Q, 2 = D.
                     Initialised with normal_(0, var).
        biasparam -- (n_layer, n_head, 2, 1, d+1); slot 0 = C, 1 = V.
                     Initialised to zeros.
    """

    def __init__(self, n_layer, n_head, N, d, var, activation):
        # super(Transformer_F, self) begins the MRO lookup *after*
        # Transformer_F, so Transformer_F.__init__ is bypassed and only the
        # parameters registered below are created. N is accepted for
        # signature compatibility but not used here.
        super(Transformer_F, self).__init__()
        width = d + 1
        all_init = torch.zeros(n_layer, n_head, 3, width, width)
        bias_init = torch.zeros(n_layer, n_head, 2, 1, width)
        self.register_parameter('allparam', torch.nn.Parameter(all_init))
        self.register_parameter('biasparam', torch.nn.Parameter(bias_init))
        with torch.no_grad():
            self.allparam.normal_(0, var)
        self.n_layer = n_layer
        self.n_head = n_head
        self.activation = activation

    def forward(self, Z):
        """
        Apply n_layer residual layers: Z <- Z + sum over heads of attention_B(...).

        Returns:
            (Z, attns_logit, attns_logit2, attns_prob, q_norm, k_norm, inout_var)
            Each list holds one entry per layer, taken from the LAST head of
            that layer:
              attns_logit, attns_logit2, attns_prob -- tensors from attention_B;
              q_norm, k_norm -- query.norm() / key.norm() as Python floats;
              inout_var -- |var(layer input) - var(head output)| as a float.
        """
        attns_logit, attns_logit2, attns_prob = [], [], []
        q_norm, k_norm, inout_var = [], [], []

        for layer in range(self.n_layer):
            layer_in = Z
            head_total = 0

            for head in range(self.n_head):
                mats = self.allparam[layer, head]      # P, Q, D
                biases = self.biasparam[layer, head]   # C, V
                (head_out, logit, logit2, prob,
                 query, key) = attention_B(mats[0], mats[1], mats[2],
                                           biases[0], biases[1],
                                           layer_in, self.activation)
                var_gap = (torch.var(layer_in) - torch.var(head_out)).abs()
                head_total = head_total + head_out

            # residual connection: F(Z) = Z + sum of head outputs
            Z = layer_in + head_total

            inout_var.append(var_gap.item())
            attns_logit.append(logit)
            attns_logit2.append(logit2)
            attns_prob.append(prob)
            q_norm.append(query.norm().item())
            k_norm.append(key.norm().item())

        return Z, attns_logit, attns_logit2, attns_prob, q_norm, k_norm, inout_var

    # Zero the P matrix (allparam[layer, head, 0]) of every layer/head in place.
    def zero_p(self):
        with torch.no_grad():
            for layer in range(self.n_layer):
                for head in range(self.n_head):
                    self.allparam[layer, head, 0].zero_()


def get_attn_ent(p):
    """
    Natural-log Shannon entropy of each distribution along the last axis.

    Computes -sum_k p_k * log(p_k) over dim=-1 with the convention
    0 * log(0) = 0 for exact zeros. The result drops the last axis.
    """
    terms = (p * torch.log(p)).masked_fill(p == 0, 0)
    return -terms.sum(dim=-1)

def get_mean_skewness(attention_weights):
    """
    Mean population skewness of a numpy array's last-axis rows.

    Skewness is computed per row along axis -1 as mean(((x - mean) / std)**3),
    with std the population std plus 1e-6 (so constant rows give 0, not NaN).
    The per-row values are then averaged over the next axis (-1 of the
    remaining shape). Returns a torch tensor of the remaining leading shape.
    """
    centred = attention_weights - np.mean(attention_weights, axis=-1, keepdims=True)
    variance = np.mean(np.power(centred, 2.0), axis=-1, keepdims=True)
    spread = np.power(variance, 0.5) + 1e-6
    row_skew = np.mean(np.power(centred / spread, 3.0), axis=-1)
    return torch.tensor(np.mean(row_skew, axis=-1))


def get_distance_matrix(seq_len):
    """
    Return the (seq_len, seq_len) float64 matrix of token distances |i - j|.

    Tokens are indexed along a single axis, so the distance between
    positions i and j is simply |i - j| (zero on the diagonal). A 2-D grid
    decomposition of the index (row = i // seq_len) always yields row 0 for
    i < seq_len, so it contributes nothing and is not used here.
    """
    positions = np.arange(seq_len)
    # broadcast column vs. row to get all pairwise differences at once
    return np.abs(positions[:, None] - positions[None, :]).astype(np.float64)


def get_mean_attention_dist(seq_len, attention_weights):
    """
    Attention-weighted token distance, averaged over query tokens.

    attention_weights: array expected (per the original note) to be
        (batch_size, num_heads, seq_len, seq_len); it must broadcast against
        a (1, 1, seq_len, seq_len) array.

    For every query row, sums weight * distance (from get_distance_matrix)
    over the last axis, then averages the result over the query axis.
    Returns a torch tensor of the remaining leading shape, e.g.
    (batch_size, num_heads).
    """
    # two leading singleton axes so the matrix broadcasts over batch and heads
    dist = get_distance_matrix(seq_len)[np.newaxis, np.newaxis]

    per_query = np.sum(attention_weights * dist, axis=-1)
    per_head = np.mean(per_query, axis=-1)

    return torch.tensor(per_head)

def kl_divergence_with_uniform(attention_probs):
    """
    Mean KL divergence between each attention row and the uniform distribution.

    attention_probs: tensor whose last axis holds probability rows over
        seq_len keys, e.g. (batch, seq_len, seq_len) or
        (batch, heads, seq_len, seq_len); any number of leading dims works.

    For each row p, computes KL(p || U) = sum_k p_k * (log p_k - log(1/seq_len)),
    with p clamped to >= 1e-9 so that zero entries do not produce log(0).
    Per-row values are clamped at 0 and averaged over all rows.

    Returns:
        The mean KL divergence as a Python float.
    """
    seq_len = attention_probs.shape[-1]
    # log of the uniform probability 1/seq_len, as a scalar instead of a full tensor
    log_uniform = -math.log(seq_len)

    # avoid log(0) for exact zeros; clamp for numerical stability
    probs = attention_probs.clamp(min=1e-9)

    kl_per_row = (probs * (torch.log(probs) - log_uniform)).sum(dim=-1)
    # the clamp above can push tiny rows slightly negative; KL is >= 0
    kl_per_row = kl_per_row.clamp(min=0)
    return kl_per_row.mean().item()


# evaluate the loss of model, given data (Z,y)
def in_context_loss(model, Z, y):
    """
    In-context regression loss of `model` on (Z, y) plus attention diagnostics.

    Args:
        model: callable; model(Z) must return the 12-tuple
            (output, attns_logit, attns_logit2, attns_prob, q_norm, k_norm,
             inout_var, inout_mean, q_var, k_var, attn_var, attn_norm),
            where attns_prob holds one attention-probability tensor per layer.
        Z: prompt tensor with N = Z.shape[1] - 1 context rows before the query
            row and d = Z.shape[2] - 1 (column d is the label slot).
        y: query targets, combined elementwise with output[:, N, d].

    Returns:
        A 33-tuple, in this order:
            loss -- mean((output[:, N, d] + y) ** 2); zero when the model
                writes -y into the query's label slot;
            attns_logit, attn_qk, attn_qk2, attns_prob, q_norm, k_norm,
            q_var, k_var, inout_var, inout_mean -- model outputs passed
                through (attn_qk, attn_qk2 are 0 placeholders);
            attn_f_ent_skew, attn_l_ent_skew -- mean skewness of the first /
                last layer's attention rows (0-dim tensors);
            attn_f_ent_dist, attn_l_ent_dist, attn_f_prob_diff,
            attn_l_prob_diff, attn_Flogit_max, attn_Llogit_max,
            attn_Flogit_min, attn_Llogit_min -- 0 placeholders;
            attn_f_ent_mean, attn_l_ent_mean -- the same list of per-layer
                mean attention entropies (floats);
            attn_f_ent_std, attn_l_ent_std -- the same list of per-layer
                entropy stds (std over the last axis, mean over the first);
            attn_f_ent_max, attn_l_ent_max, attn_f_ent_min,
            attn_l_ent_min -- 0 placeholders;
            attn_var -- passed through from the model;
            attn_ent_tmp -- per-layer flattened entropies as Python lists;
            attn_norm -- passed through from the model;
            kl_mean -- per-layer mean KL(attention row || uniform) as floats.
    """
    N = Z.shape[1] - 1
    d = Z.shape[2] - 1
    (output, attns_logit, attns_logit2, attns_prob, q_norm, k_norm,
     inout_var, inout_mean, q_var, k_var, attn_var, attn_norm) = model(Z)

    # Diagnostics that are currently disabled; kept as 0 so every slot of
    # the returned tuple stays in place for callers.
    attn_qk = attn_qk2 = 0
    attn_Flogit_max = attn_Llogit_max = 0
    attn_Flogit_min = attn_Llogit_min = 0
    attn_f_ent_dist = attn_l_ent_dist = 0
    attn_f_ent_max = attn_l_ent_max = 0
    attn_f_ent_min = attn_l_ent_min = 0
    attn_f_prob_diff = attn_l_prob_diff = 0

    # Everything below is logging-only (turned into floats / lists / numpy),
    # so there is no reason to build autograd graphs for it.
    with torch.no_grad():
        attn_entropy = []
        attn_entropy_std = []
        attn_ent_tmp = []
        for probs in attns_prob:
            ent = get_attn_ent(probs)  # computed once per layer and reused
            attn_entropy.append(torch.mean(torch.mean(ent, dim=-1), dim=0).item())
            attn_entropy_std.append(torch.mean(torch.std(ent, dim=-1), dim=0).item())
            attn_ent_tmp.append(ent.flatten().detach().cpu().numpy().tolist())

        # .detach() is still required: .cpu()/.float() may return the original
        # grad-requiring tensor, and .numpy() refuses those even under no_grad.
        attn_f_ent_skew = torch.mean(torch.mean(
            get_mean_skewness(attns_prob[0].cpu().float().detach().numpy())), dim=0)
        attn_l_ent_skew = torch.mean(torch.mean(
            get_mean_skewness(attns_prob[-1].cpu().float().detach().numpy())), dim=0)

        kl_mean = [kl_divergence_with_uniform(weights) for weights in attns_prob]

    # The same per-layer lists are reported under both "first" and "last" names.
    attn_f_ent_mean = attn_l_ent_mean = attn_entropy
    attn_f_ent_std = attn_l_ent_std = attn_entropy_std

    # Loss compares output[:, N, d] with -y (note the '+').
    diff = output[:, N, d] + y
    loss = (diff ** 2).mean()

    return (loss, attns_logit, attn_qk, attn_qk2, attns_prob, q_norm, k_norm,
            q_var, k_var, inout_var, inout_mean, attn_f_ent_skew, attn_l_ent_skew,
            attn_f_ent_dist, attn_l_ent_dist, attn_f_prob_diff, attn_l_prob_diff,
            attn_Flogit_max, attn_Llogit_max, attn_Flogit_min, attn_Llogit_min,
            attn_f_ent_mean, attn_l_ent_mean, attn_f_ent_std, attn_l_ent_std,
            attn_f_ent_max, attn_l_ent_max, attn_f_ent_min, attn_l_ent_min,
            attn_var, attn_ent_tmp, attn_norm, kl_mean)
        
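# Generate random data for linear regression.
# mode: distribution of samples to generate. Currently supports 'normal', 'gamma', 'sphere';
#       generate_data2 also accepts 'relu' and 'mlp', which call generate_data_relu / generate_data_mlp.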
# N: number of context examples
# d: dimension of covariates
# For gamma distribution:
# - shape_k: shape parameter of gamma distribution (unused otherwise)
# - scale parameter: hard coded so that when shape_k = 5/2 and d=5, the generated data is standard normal



def generate_data2(var, mode='normal',N=20,d=1,B=1000,shape_k=0.1, U=None, D=None):
    """
    Sample a batch of B noiseless in-context linear-regression prompts.

    Args:
        var: passed as the *std* argument of normal_(0, var) when drawing
            W (B, d), context covariates X (B, N, d) and query X_test (B, 1, d).
        mode: 'normal' (covariates as drawn), 'sphere' (each covariate vector
            scaled to unit L2 norm), 'gamma' (unit vectors rescaled by
            sqrt(Gamma(shape_k, scale=sqrt(10/shape_k))) radii), or
            'relu' / 'mlp', which delegate to generate_data_relu /
            generate_data_mlp (ignoring var, shape_k, U and D).
        N, d, B: context length, covariate dimension, batch size.
        shape_k: gamma shape parameter (only used by mode='gamma').
        U, D: optional matrices; if given, W is redrawn as N(0, 1) and mapped
            to W D^{-1} U^T, and covariates are mapped x -> U D x.

    Returns:
        Z: (B, N+1, d+1) prompts; rows are [x_n, y_n], the last row is the
            query with its label slot set to 0.
        y_test: (B,) labels of the query rows.

    Raises:
        ValueError: for an unsupported mode.
    """
    W = torch.FloatTensor(B, d).normal_(0, var).to(device)
    X = torch.FloatTensor(B, N, d).normal_(0, var).to(device)
    X_test = torch.FloatTensor(B, 1, d).normal_(0, var).to(device)

    if U is not None:
        U = U.to(device)
        D = D.to(device)
        # Fresh N(0, 1) weights mapped through D^{-1} U^T. The W drawn above
        # is discarded but has already advanced the RNG.
        W = torch.FloatTensor(B, d).normal_(0, 1).to(device)
        W = torch.mm(W, torch.inverse(D))
        W = torch.mm(W, U.t())

    if mode == 'sphere':
        X.div_(X.norm(p=2, dim=2)[:, :, None])
        X_test.div_(X_test.norm(p=2, dim=2)[:, :, None])
    elif mode == 'gamma':
        gamma_scale = (10 / shape_k) ** (0.5)
        # random gamma radii (numpy RNG): context first, then query
        gamma_scales = np.random.gamma(shape=shape_k, scale=gamma_scale, size=[B, N])
        gamma_scales = torch.Tensor(gamma_scales).to(device).sqrt()
        gamma_test_scales = np.random.gamma(shape=shape_k, scale=gamma_scale, size=[B, 1])
        gamma_test_scales = torch.Tensor(gamma_test_scales).to(device).sqrt()
        # normalize to unit norm, then scale by the gamma radii
        X.div_(X.norm(p=2, dim=2)[:, :, None])
        X_test.div_(X_test.norm(p=2, dim=2)[:, :, None])
        X.mul_(gamma_scales[:, :, None])
        X_test.mul_(gamma_test_scales[:, :, None])
    elif mode == 'normal':
        pass
    elif mode == 'relu':
        return generate_data_relu(N=N, d=d, B=B, hidden_dim=d)
    elif mode == 'mlp':
        # delegate entirely, like 'relu'; the normal pipeline below must not run
        return generate_data_mlp(N=N, d=d, B=B, hidden_dim=d)
    else:
        # explicit raise: `assert False` would be stripped under `python -O`
        raise ValueError(f"unsupported mode: {mode!r}")

    if U is not None:
        X = torch.einsum('ij, jk, BNk -> BNi', (U, D, X))
        X_test = torch.einsum('ij, jk, BNk -> BNi', (U, D, X_test))

    # labels y_n = <w_b, x_n>; the query label is returned separately
    y = torch.einsum('bi,bni->bn', (W, X)).unsqueeze(2)
    y_zero = torch.zeros(B, 1, 1).to(device)
    y_test = torch.einsum('bi,bni->bn', (W, X_test)).squeeze(1)
    X_comb = torch.cat([X, X_test], dim=1)
    y_comb = torch.cat([y, y_zero], dim=1)
    Z = torch.cat([X_comb, y_comb], dim=2)
    return Z.to(device), y_test.to(device)


def generate_data_inplace(Z, U=None, D=None):
    """
    Refill an existing prompt tensor Z in place with a fresh regression task.

    Z is indexed Z[b, n, k]: the first d = Z.shape[2] - 1 columns are
    covariates, the last column is the label slot, and the last row is the
    query. The covariates are redrawn from N(0, 1) in place. Per-prompt
    weights W ~ N(0, 1) are drawn; if U and D are given, W is mapped to
    W D^{-1} U^T and the covariates to U D x. Labels <w_b, x_n> are written
    for every row, the query's label is returned as y_test, and then it is
    zeroed in Z.

    W, U and D are moved to Z's device and dtype, so Z need not live on the
    module-level `device` nor be float32.

    Returns:
        (Z.to(device), y_test.to(device)) with y_test of shape (B,).
    """
    B = Z.shape[0]
    d = Z.shape[2] - 1

    X = Z[:, :, :-1]  # view: writes below go straight into Z
    X.normal_(0, 1)

    # drawn with the default CPU generator, then matched to Z's device/dtype
    W = torch.FloatTensor(B, d).normal_(0, 1).to(Z)
    if U is not None:
        U = U.to(Z)
        D = D.to(Z)
        W = torch.mm(W, torch.inverse(D))
        W = torch.mm(W, U.t())
        Z[:, :, :-1] = torch.einsum('ij, jk, BNk -> BNi', (U, D, X))

    Z[:, :, -1] = torch.einsum('bi,bni->bn', (W, Z[:, :, :-1]))  # label update
    y_test = Z[:, -1, -1].detach().clone()
    Z[:, -1, -1].zero_()
    return Z.to(device), y_test.to(device)

def generate_data_sine(N=10, B=1000):
    """
    Sample B sine-regression tasks with N points each.

    Task b is y = a_b * sin(x + p_b) with amplitude a_b ~ U(0.1, 5) and
    phase p_b ~ U(0, pi); inputs x ~ U(-5, 5). Draw order: amplitudes,
    phases, inputs.

    Returns:
        X, Y: tensors of shape (B, N, 1) on `device`.
    """
    amplitude = torch.FloatTensor(B).uniform_(0.1, 5).to(device)
    phase = torch.FloatTensor(B).uniform_(0, math.pi).to(device)
    inputs = torch.FloatTensor(B, N).uniform_(-5, 5).to(device)

    outputs = amplitude[:, None] * torch.sin(phase[:, None] + inputs)

    return inputs[..., None], outputs[..., None]

def generate_data_relu(mode='normal', N=20, d=1, B=1000, shape_k=0.1, U=None, D=None, hidden_dim=100):
    """
    Sample prompts whose labels come from a random one-hidden-layer ReLU net.

    Covariates are N(0, 1). A single network d -> hidden_dim -> 1 is built
    per call (both weight matrices re-drawn via normal_(0, 0.1); biases keep
    nn.Linear's default init) and shared by all B prompts.

    mode, shape_k, U and D are accepted for signature compatibility with
    generate_data2 but are currently ignored.

    Returns:
        Z: (B, N+1, d+1) prompts [x_n, y_n]; the query row's label slot is 0.
        y_test: (B,) labels of the query rows.
    Labels are computed under torch.no_grad(), so neither output requires
    grad or keeps the temporary network's autograd graph alive.
    """
    X = torch.FloatTensor(B, N, d).normal_(0, 1).to(device)
    X_test = torch.FloatTensor(B, 1, d).normal_(0, 1).to(device)

    model = nn.Sequential(
        nn.Linear(d, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, 1)
    ).to(device)
    model[0].weight.data.normal_(0, 0.1)
    model[2].weight.data.normal_(0, 0.1)

    # The network is a throwaway label generator; evaluating it with autograd
    # on would make the training data require grad and backprop into it.
    with torch.no_grad():
        y = model(X.view(-1, d)).view(B, N, 1)
        y_test = model(X_test.view(-1, d)).view(B, 1).squeeze(1)

    y_zero = torch.zeros(B, 1, 1).to(device)
    X_comb = torch.cat([X, X_test], dim=1)
    y_comb = torch.cat([y, y_zero], dim=1)
    Z = torch.cat([X_comb, y_comb], dim=2)

    return Z, y_test

def generate_data_mlp(N=20, d=1, B=1000, hidden_dim=100):
    """
    Linear-regression prompts on covariates warped by a random ReLU MLP.

    Raw inputs x ~ N(0, 1) are mapped through one network
    d -> hidden_dim -> d (both weight matrices re-drawn via normal_(0, 1);
    biases keep nn.Linear's default init), shared by all B prompts. Labels
    are y = <w_b, mlp(x)> with per-prompt w_b ~ N(0, 1). The returned prompts
    hold the *transformed* covariates mlp(x).

    Returns:
        Z: (B, N+1, d+1) prompts; the query row's label slot is 0.
        y_test: (B,) labels of the query rows.
    The MLP is evaluated under torch.no_grad(), so neither output requires
    grad or keeps the temporary network's autograd graph alive.
    """
    X = torch.FloatTensor(B, N, d).normal_(0, 1).to(device)
    X_test = torch.FloatTensor(B, 1, d).normal_(0, 1).to(device)

    model = nn.Sequential(
        nn.Linear(d, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, d)
    ).to(device)
    model[0].weight.data.normal_(0, 1)
    model[2].weight.data.normal_(0, 1)

    # The MLP is a throwaway feature map; keep it out of the autograd graph.
    with torch.no_grad():
        X_MLP = model(X.view(-1, d)).view(B, N, d)
        X_test_MLP = model(X_test.view(-1, d)).view(B, 1, d)

        W = torch.FloatTensor(B, d).normal_(0, 1).to(device)
        y = torch.einsum('bi,bni->bn', (W, X_MLP)).unsqueeze(2)
        y_test = torch.einsum('bi,bni->bn', (W, X_test_MLP)).squeeze(1)

    y_zero = torch.zeros(B, 1, 1).to(device)
    X_comb = torch.cat([X_MLP, X_test_MLP], dim=1)
    y_comb = torch.cat([y, y_zero], dim=1)
    Z = torch.cat([X_comb, y_comb], dim=2)

    return Z, y_test
