# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

from APLinear import APLinearSel
from plugin import *
from flash_attn import flash_attn_with_kvcache

import transformers
import nvtx

from tqdm import tqdm

def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to a multiple of ``k``.

    Returns ``n`` unchanged when it is already divisible by ``k``; otherwise
    returns ``n + k - (n % k)``, which for positive ``k`` is the next multiple
    of ``k`` above ``n``.
    """
    remainder = n % k
    return n if remainder == 0 else n + k - remainder

@dataclass
class ModelArgs:
    """Architecture hyper-parameters for a transformer model.

    Fields left at their sentinel defaults are filled in by ``__post_init__``:

    * ``n_local_heads == -1`` is replaced by ``n_head``.
    * ``intermediate_size is None`` is replaced by ``2/3 * 4 * dim`` rounded
      up to a multiple of 256 (via ``find_multiple``).
    * ``head_dim`` is *always* recomputed as ``dim // n_head``; any value
      passed to the constructor is overwritten.
    """

    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    # None means "derive from dim" (see __post_init__).
    intermediate_size: Optional[int] = None
    # -1 means "same as n_head" (see __post_init__).
    n_local_heads: int = -1
    # Overwritten unconditionally in __post_init__.
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    rope_scaling: Optional[dict] = None
    model_name: Optional[str] = None

    def __post_init__(self):
        """Resolve sentinel defaults and derive ``head_dim``."""
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head

    @classmethod
    def from_name(cls, name: str):
        """Build a ``ModelArgs`` from an entry of ``transformer_configs``.

        Args:
            name: Exact key of ``transformer_configs`` (no fuzzy matching).

        Returns:
            A new instance constructed from the matching keyword arguments.

        Raises:
            AssertionError: If ``name`` is not a key of ``transformer_configs``.
        """
        assert name in transformer_configs, f"Unknown model name: {name}, available: {transformer_configs.keys()}"
        return cls(**transformer_configs[name])


# Keyword arguments for ModelArgs, keyed by exact model name; looked up by
# ModelArgs.from_name. Fields omitted from an entry (e.g. block_size and
# rope_base for Mistral) fall back to the ModelArgs dataclass defaults.
transformer_configs = {
    "Meta-Llama-2-7B": {
        "model_name": "Meta-Llama-2-7B",
        "block_size": 2048,
        "n_layer": 32,
        "n_head": 32,
        "n_local_heads": -1,
        "dim": 4096,
        "intermediate_size": 11008,
        "vocab_size": 32000,
        "rope_base": 10000,
    },
    "Meta-Llama-3-8B-Instruct": {
        "model_name": "Meta-Llama-3-8B-Instruct",
        "block_size": 8192,
        "n_layer": 32,
        "n_head": 32,
        "n_local_heads": 8,
        "dim": 4096,
        "intermediate_size": 14336,
        "vocab_size": 128256,
        "rope_base": 500000,
    },
    "Phi-3-medium-4k-instruct": {
        "model_name": "Phi-3-medium-4k-instruct",
        "block_size": 4096,
        "n_layer": 40,
        "n_head": 40,
        "n_local_heads": 10,
        "dim": 5120,
        "intermediate_size": 17920,
        "vocab_size": 32064,
        "rope_base": 10000,
    },
    "Mistral-7B-v0.3": {
        "model_name": "Mistral-7b-v0.3",
        "n_layer": 32,
        "n_head": 32,
        "n_local_heads": 8,
        "dim": 4096,
        "intermediate_size": 14336,
        "vocab_size": 32000,
    },
    "Qwen2.5-14B": {
        "model_name": "Qwen2.5-14B",
        "block_size": 2048,
        "n_layer": 48,
        "n_head": 40,
        "n_local_heads": 8,
        "dim": 5120,
        "intermediate_size": 13824,
        "vocab_size": 131072,
        "rope_base": 10000,
    },
}

class KVCache(nn.Module):
    """Pre-allocated key/value buffers that are written in place.

    Both buffers are zero-initialised with shape
    ``(max_batch_size, n_heads, max_seq_length, head_dim)``.
    """

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.half):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        # Keys first, then values, so the buffer registration order is stable.
        for buffer_name in ('k_cache', 'v_cache'):
            self.register_buffer(buffer_name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write ``k_val``/``v_val`` at the positions in ``input_pos``.

        ``input_pos`` indexes the third dimension of the buffers, and its
        length must equal ``k_val.shape[2]`` (input_pos: [S], k_val: [B, H, S, D]).

        Returns:
            The full ``(k_cache, v_cache)`` buffers after the in-place write.
        """
        assert input_pos.shape[0] == k_val.shape[2]

        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val

        return self.k_cache, self.v_cache

def _parse_threshold_literal(text: str):
    """Parse ``text`` as a Python literal (the numeric threshold of a spec)."""
    # SECURITY: this used to be a bare eval() on text taken from a config
    # spec, which would execute arbitrary expressions. literal_eval accepts
    # the same numeric literals without executing code.
    import ast

    return ast.literal_eval(text)


def _trailing_threshold(token: str):
    """Read the threshold at the end of ``token``: last two chars, else last one."""
    try:
        return _parse_threshold_literal(token[-2:])
    except (ValueError, SyntaxError):
        # e.g. "qkv4": "v4" is not a literal, so fall back to the final "4".
        return _parse_threshold_literal(token[-1:])


def _build_targ_dict(s: str, n: int, module_names) -> dict:
    """Shared implementation of parse_config / parse_config_fused."""
    s = s.strip("\n")
    # -1 marks "not configured yet", so -1 itself cannot be used as a threshold.
    thresholds = dict.fromkeys(module_names, -1)

    all_pos = s.find("all")
    if all_pos != -1:
        # "all<value>": everything after "all" is the threshold for every module.
        shared = _parse_threshold_literal(s[all_pos + len("all"):])
        thresholds = dict.fromkeys(module_names, shared)

    for token in s.split("+"):
        value = _trailing_threshold(token)
        for name in module_names:
            # A module is selected when its first letter occurs anywhere in the token.
            if name[0] in token:
                if thresholds[name] != -1:
                    raise RuntimeError(f"Double config for {name}")
                thresholds[name] = value

    targ_dict = {}
    for name in module_names:
        if thresholds[name] == -1:
            raise RuntimeError(f"config not set for {name}")
        for i in range(n + 1):
            targ_dict[(i, name)] = -1000000.0 if i >= thresholds[name] else 1000000.0

    print(thresholds)

    return targ_dict


def parse_config(s: str, n: int) -> dict:
    """Expand a per-module threshold spec into a ``(layer, module) -> target`` dict.

    The spec is either ``"all<value>"`` or ``"+"``-separated tokens. Each token
    holds the first letters of the modules it applies to (``q``, ``k``, ``v``,
    ``o``, ``g``, ``u``, ``d``) followed by a one- or two-character threshold,
    e.g. ``"qkv4+o5+gud12"``.

    Args:
        s: The spec string; leading/trailing newlines are stripped.
        n: Highest layer index; entries are produced for ``0..n`` inclusive.

    Returns:
        A dict mapping ``(i, module_name)`` to ``-1000000.0`` when
        ``i >= threshold`` for that module and ``1000000.0`` otherwise.

    Raises:
        RuntimeError: If a module is configured twice or not at all.
    """
    return _build_targ_dict(
        s, n,
        ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"),
    )


def parse_config_fused(s: str, n: int) -> dict:
    """Same as ``parse_config`` but for the fused-projection module set.

    The modules are ``qkv_proj``, ``o_proj``, ``gate_up_proj`` and
    ``down_proj``, selected by the letters ``q``, ``o``, ``g`` and ``d``.

    Args:
        s: The spec string; leading/trailing newlines are stripped.
        n: Highest layer index; entries are produced for ``0..n`` inclusive.

    Returns:
        A dict mapping ``(i, module_name)`` to ``-1000000.0`` when
        ``i >= threshold`` for that module and ``1000000.0`` otherwise.

    Raises:
        RuntimeError: If a module is configured twice or not at all.
    """
    return _build_targ_dict(
        s, n,
        ("qkv_proj", "o_proj", "gate_up_proj", "down_proj"),
    )


class Transformer(nn.Module):
    def __init__(self, dtype, config: ModelArgs, linear_class=nn.Linear, linear_kwargs=None, halve_layers=False) -> None:
        super().__init__()
        self.config = config
        self.dtype = dtype

        prec_arr = [3,4,5,6]

        # if halve_layers, halve the number of layers for testing purposes
        if halve_layers:
            config.n_layer = config.n_layer // 2

        with open("config/bsel_config.txt") as f:
            bsel_config = f.readline()
            
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        if "phi" in config.model_name.lower():
            self.corr_arr_path = f"config/phi-3-medium_corr_arr_0.9.pt"
            self.th_arr_path = f"config/phi-3-medium_th_arr.pt"
            self.th_layerbits_path = f"config/phi-3-medium_th_arr_layerbits.pt"
            self.corr_arr= torch.load(self.corr_arr_path)
            self.corr_dict = {}
            th_arr, max_mem_dict = torch.load(self.th_arr_path)
            bh_yes_arr= torch.load(self.th_layerbits_path)

            self.b_dict = {}
            for (l, n, slope, inter, _, b_l, b_h) in self.corr_arr:
                self.corr_dict[(l,n)] = (slope.item(), inter.item())
                
            self.th_dict = {}
            targ_dict = {}
            self.module_arr = ["qkv_proj","o_proj","gate_up_proj","down_proj"]
            for l in range(config.n_layer):
                for n in self.module_arr:
                    th = th_arr[l*len(self.module_arr)+self.module_arr.index(n)]
                    bh_yes = bh_yes_arr[l*len(self.module_arr)+self.module_arr.index(n)]
                    maxmem = max_mem_dict[(l, n)]
                    if maxmem == 6:
                        if th < -0.5:
                            b_l = prec_arr[0]
                            b_h = prec_arr[1]
                            th = -th-0.5
                        elif th < 0.5:
                            b_l = prec_arr[1]
                            b_h = prec_arr[2]
                            th = -th+0.5
                        else:
                            b_l = prec_arr[2]
                            b_h = prec_arr[3]
                            th= -th+1.5
                    elif maxmem == 5:
                        if th < 0:
                            b_l = prec_arr[0]
                            b_h = prec_arr[1]
                            th = -th
                        else:
                            b_l = prec_arr[1]
                            b_h = prec_arr[2]
                            th = 1-th
                    elif maxmem == 4:
                        b_l = prec_arr[0]
                        b_h = prec_arr[1]
                        th = 1-th
                    elif maxmem == 3:
                        b_l = prec_arr[0]
                        b_h = prec_arr[0]
                    
                    th = torch.tensor(th, dtype=float, device="cuda")
                    self.th_dict[(l,n)] = th
                    self.b_dict[(l,n)] = (torch.tensor(b_l, dtype=torch.int, device="cuda"), 
                                        torch.tensor(b_h, dtype=torch.int, device="cuda"))

                    targ_dict[(l,n)] = -1000000.0 if bh_yes == 1 else 1000000.0

            # Exception handling for last layer
            for l in [config.n_layer]:
                for n in self.module_arr:
                    targ_dict[(l,n)] = 0.5
                    self.b_dict[(l,n)] = (torch.tensor(-1, dtype=torch.int, device="cuda"), 
                                        torch.tensor(-1, dtype=torch.int, device="cuda"))

            self.layers = nn.ModuleList(TransformerBlock_fused(config, linear_class, linear_kwargs, 
                                                         intra_class_dict={'gate_up':(selectorbg_gu_linreg if ((l,"gate_up_proj") in self.corr_dict.keys())
                                                                                   else selectorbg_gu_gemv)},
                                                         intra_dict={'gate_up':{'a': (self.corr_dict[(l,"gate_up_proj")][0] 
                                                                                    if ((l,"gate_up_proj") in self.corr_dict.keys())
                                                                                   else None), 
                                                                            'b': (self.corr_dict[(l,"gate_up_proj")][1] 
                                                                                    if ((l,"gate_up_proj") in self.corr_dict.keys())
                                                                                   else None),
                                                                             'targ': targ_dict[(l,"gate_up_proj")],
                                                                             'b_lh': self.b_dict[(l,"gate_up_proj")]}},
                                                         inter_class_dict={'qkv':(selectorbg_qkv_linreg if ((l+1,"qkv_proj") in self.corr_dict.keys() )
                                                                                   else selectorbg_qkv_gemv)},
                                                         inter_dict={'qkv':{'a': (self.corr_dict[(l+1,"qkv_proj")][0] 
                                                                                if ((l+1,"qkv_proj") in self.corr_dict.keys())
                                                                                   else None), 
                                                                          'b': (self.corr_dict[(l+1,"qkv_proj")][1] 
                                                                                if ((l+1,"qkv_proj") in self.corr_dict.keys())
                                                                                   else None),
                                                                             'targ': targ_dict[(l+1,"qkv_proj")],
                                                                             'b_lh': self.b_dict[(l+1,"qkv_proj")]}},

                                                         wo_class=selector_linreg if ((l, "o_proj") in self.corr_dict.keys())
                                                                    else selector_gemv,
                                                         wo_dict={'a': (self.corr_dict[(l,"o_proj")][0] 
                                                                        if ((l,"o_proj") in self.corr_dict.keys())
                                                                        else None), 
                                                                  'b': (self.corr_dict[(l,"o_proj")][1] 
                                                                        if ((l,"o_proj") in self.corr_dict.keys())
                                                                        else None),
                                                                  'targ': targ_dict[(l,"o_proj")],
                                                                  'b_lh': self.b_dict[(l,"o_proj")]},
                                                         w2_class=selector_linreg if ((l, "down_proj") in self.corr_dict.keys())
                                                                    else selector_gemv,
                                                         w2_dict={'a': (self.corr_dict[(l,"down_proj")][0] 
                                                                        if ((l,"down_proj") in self.corr_dict.keys())
                                                                        else None), 
                                                                  'b': (self.corr_dict[(l,"down_proj")][1] 
                                                                        if ((l,"down_proj") in self.corr_dict.keys())
                                                                        else None),
                                                                  'targ': targ_dict[(l,"down_proj")],
                                                                  'b_lh': self.b_dict[(l,"down_proj")]},
                                                         ) for l in range(config.n_layer))
            self.norm = Phi3RMSNorm(config.dim, eps=config.norm_eps)
        else:
            if "llama" in config.model_name.lower():
                self.corr_arr_path = f"config/llama-3-8b_corr_arr_0.9.pt"
                self.th_arr_path = f"config/llama-3-8b_th_arr.pt"
                self.th_layerbits_path = f"config/llama-3-8b_th_arr_layerbits.pt"
            elif "mistral" in config.model_name.lower():
                self.corr_arr_path = f"config/mistral-7b-v0.3_corr_arr_0.9.pt"
            elif "qwen" in config.model_name.lower():
                self.corr_arr_path = f"config/qwen2.5-14b_corr_arr_0.9.pt"
            self.corr_arr = torch.load(self.corr_arr_path)
            self.corr_dict = {}
            th_arr, max_mem_dict = torch.load(self.th_arr_path)
            bh_yes_arr= torch.load(self.th_layerbits_path)

            self.b_dict = {}
            for (l, n, slope, inter, _, b_l, b_h) in self.corr_arr:
                self.corr_dict[(l,n)] = (slope.item(), inter.item())
                
            self.th_dict = {}
            targ_dict = {}
            self.module_arr = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
            for l in range(config.n_layer):
                for n in self.module_arr:
                    th = th_arr[l*len(self.module_arr)+self.module_arr.index(n)]
                    bh_yes = bh_yes_arr[l*len(self.module_arr)+self.module_arr.index(n)]
                    maxmem = max_mem_dict[(l, n)]
                    if maxmem == 6:
                        if th < -0.5:
                            b_l = prec_arr[0]
                            b_h = prec_arr[1]
                            th = -th-0.5
                        elif th < 0.5:
                            b_l = prec_arr[1]
                            b_h = prec_arr[2]
                            th = -th+0.5
                        else:
                            b_l = prec_arr[2]
                            b_h = prec_arr[3]
                            th= -th+1.5
                    elif maxmem == 5:
                        if th < 0:
                            b_l = prec_arr[0]
                            b_h = prec_arr[1]
                            th = -th
                        else:
                            b_l = prec_arr[1]
                            b_h = prec_arr[2]
                            th = 1-th
                    elif maxmem == 4:
                        b_l = prec_arr[0]
                        b_h = prec_arr[1]
                        th = 1-th
                    elif maxmem == 3:
                        b_l = prec_arr[0]
                        b_h = prec_arr[0]
                    
                    th = torch.tensor(th, dtype=float, device="cuda")
                    self.th_dict[(l,n)] = th
                    self.b_dict[(l,n)] = (torch.tensor(b_l, dtype=torch.int, device="cuda"), 
                                        torch.tensor(b_h, dtype=torch.int, device="cuda"))
                    
                    targ_dict[(l,n)] = -1000000.0 if bh_yes == 1 else 1000000.0

            # Exception handling for last layer
            for l in [config.n_layer]:
                for n in self.module_arr:
                    targ_dict[(l,n)] = 0.5
                    self.b_dict[(l,n)] = (torch.tensor(-1, dtype=torch.int, device="cuda"), 
                                        torch.tensor(-1, dtype=torch.int, device="cuda"))

            self.layers = nn.ModuleList(TransformerBlock(config, linear_class, linear_kwargs, 
                                                         intra_class_dict={'gate':(selectorbg_g_linreg if ((l,"gate_proj") in self.corr_dict.keys() )
                                                                                   else selectorbg_g_gemv),
                                                                            'up':(selectorbg_u_linreg if ((l,"up_proj") in self.corr_dict.keys() )
                                                                                   else selectorbg_u_gemv)},
                                                         intra_dict={'gate':{'a': (self.corr_dict[(l,"gate_proj")][0] 
                                                                                    if ((l,"gate_proj") in self.corr_dict.keys())
                                                                                   else None), 
                                                                            'b': (self.corr_dict[(l,"gate_proj")][1] 
                                                                                    if ((l,"gate_proj") in self.corr_dict.keys())
                                                                                   else None),
                                                                             'targ': targ_dict[(l,"gate_proj")],
                                                                             'b_lh': self.b_dict[(l,"gate_proj")]},
                                                                     'up':{'a': (self.corr_dict[(l,"up_proj")][0] 
                                                                                    if ((l,"up_proj") in self.corr_dict.keys())
                                                                                   else None), 
                                                                            'b': (self.corr_dict[(l,"up_proj")][1] 
                                                                                    if ((l,"up_proj") in self.corr_dict.keys())
                                                                                   else None),
                                                                             'targ': targ_dict[(l,"up_proj")],
                                                                             'b_lh': self.b_dict[(l,"up_proj")]}},
                                                         inter_class_dict={'q':(selectorbg_q_linreg if ((l+1,"q_proj") in self.corr_dict.keys() )
                                                                                   else selectorbg_q_gemv),
                                                                            'k':(selectorbg_k_linreg if ((l+1,"k_proj") in self.corr_dict.keys() )
                                                                                   else selectorbg_k_gemv),
                                                                            'v':(selectorbg_v_linreg if ((l+1,"v_proj") in self.corr_dict.keys() )
                                                                                   else selectorbg_v_gemv)},
                                                         inter_dict={'q':{'a': (self.corr_dict[(l+1,"q_proj")][0] 
                                                                                if ((l+1,"q_proj") in self.corr_dict.keys())
                                                                                   else None), 
                                                                          'b': (self.corr_dict[(l+1,"q_proj")][1] 
                                                                                if ((l+1,"q_proj") in self.corr_dict.keys())
                                                                                   else None),
                                                                             'targ': targ_dict[(l+1,"q_proj")],
                                                                             'b_lh': self.b_dict[(l+1,"q_proj")]},
                                                                     'k':{'a': (self.corr_dict[(l+1,"k_proj")][0] 
                                                                                if ((l+1,"k_proj") in self.corr_dict.keys())
                                                                                   else None), 
                                                                          'b': (self.corr_dict[(l+1,"k_proj")][1] 
                                                                                if ((l+1,"k_proj") in self.corr_dict.keys())
                                                                                   else None),
                                                                             'targ': targ_dict[(l+1,"k_proj")],
                                                                             'b_lh': self.b_dict[(l+1,"k_proj")]},
                                                                     'v':{'a': (self.corr_dict[(l+1,"v_proj")][0] 
                                                                                if ((l+1,"v_proj") in self.corr_dict.keys())
                                                                                   else None), 
                                                                          'b': (self.corr_dict[(l+1,"v_proj")][1] 
                                                                                if ((l+1,"v_proj") in self.corr_dict.keys())
                                                                                   else None),
                                                                             'targ': targ_dict[(l+1,"v_proj")],
                                                                             'b_lh': self.b_dict[(l+1,"v_proj")]}},

                                                         wo_class=selector_linreg if ((l, "o_proj") in self.corr_dict.keys())
                                                                    else selector_gemv,
                                                         wo_dict={'a': (self.corr_dict[(l,"o_proj")][0] 
                                                                        if ((l,"o_proj") in self.corr_dict.keys())
                                                                        else None), 
                                                                  'b': (self.corr_dict[(l,"o_proj")][1] 
                                                                        if ((l,"o_proj") in self.corr_dict.keys())
                                                                        else None),
                                                                  'targ': targ_dict[(l,"o_proj")],
                                                                  'b_lh': self.b_dict[(l,"o_proj")]},
                                                         w2_class=selector_linreg if ((l, "down_proj") in self.corr_dict.keys())
                                                                    else selector_gemv,
                                                         w2_dict={'a': (self.corr_dict[(l,"down_proj")][0] 
                                                                        if ((l,"down_proj") in self.corr_dict.keys())
                                                                        else None), 
                                                                  'b': (self.corr_dict[(l,"down_proj")][1] 
                                                                        if ((l,"down_proj") in self.corr_dict.keys())
                                                                        else None),
                                                                  'targ': targ_dict[(l,"down_proj")],
                                                                  'b_lh': self.b_dict[(l,"down_proj")]},
                                                         ) for l in range(config.n_layer))
            
            if "qwen" in config.model_name.lower():
                self.norm = Phi3RMSNorm(config.dim, eps=config.norm_eps)
            else:
                self.norm = RMSNorm(config.dim, eps=config.norm_eps)
            

        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
        if "phi" in config.model_name.lower():
            for i, block in enumerate(self.layers):
                if i == 0:
                    self.layers[i] = TransformerBlockFirst_fused(config, linear_class, linear_kwargs, 
                                                            intra_class_dict={'gate_up':(selectorbg_gu_linreg if ((i,"gate_up_proj") in self.corr_dict.keys() )
                                                                                    else selectorbg_gu_gemv)},
                                                            intra_dict={'gate_up':{'a': (self.corr_dict[(i,"gate_up_proj")][0] 
                                                                                        if ((i,"gate_up_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                                'b': (self.corr_dict[(i,"gate_up_proj")][1] 
                                                                                        if ((i,"gate_up_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                                'targ': targ_dict[(i,"gate_up_proj")],
                                                                                'b_lh': self.b_dict[(i,"gate_up_proj")]}},
                                                            inter_class_dict={'qkv':(selectorbg_qkv_linreg if ((i+1,"qkv_proj") in self.corr_dict.keys() )
                                                                                    else selectorbg_qkv_gemv)},
                                                            inter_dict={'qkv':{'a': (self.corr_dict[(i+1,"qkv_proj")][0] 
                                                                                    if ((i+1,"qkv_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                            'b': (self.corr_dict[(i+1,"qkv_proj")][1] 
                                                                                    if ((i+1,"qkv_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                                'targ': targ_dict[(i+1,"qkv_proj")],
                                                                                'b_lh': self.b_dict[(i+1,"qkv_proj")]}},
                                                            trigger_class_dict={'qkv':(trigger_linreg if ((i,"qkv_proj") in self.corr_dict.keys() )
                                                                                    else trigger_gemv)},
                                                            trigger_dict={'qkv':{'a': (self.corr_dict[(i,"qkv_proj")][0] 
                                                                                    if ((i,"qkv_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                            'b': (self.corr_dict[(i,"qkv_proj")][1] 
                                                                                    if ((i,"qkv_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                            'targ': targ_dict[(i,"qkv_proj")],
                                                                            'b_lh': self.b_dict[(i,"qkv_proj")]}},
                                                            wo_class=selector_linreg if ((i, "o_proj") in self.corr_dict.keys())
                                                                        else selector_gemv,
                                                            wo_dict={'a': (self.corr_dict[(i,"o_proj")][0] 
                                                                            if ((i,"o_proj") in self.corr_dict.keys())
                                                                            else None), 
                                                                    'b': (self.corr_dict[(i,"o_proj")][1] 
                                                                            if ((i,"o_proj") in self.corr_dict.keys())
                                                                            else None),
                                                                    'targ': targ_dict[(i,"o_proj")],
                                                                    'b_lh': self.b_dict[(i,"o_proj")]},
                                                            w2_class=selector_linreg if ((i, "down_proj") in self.corr_dict.keys())
                                                                        else selector_gemv,
                                                            w2_dict={'a': (self.corr_dict[(i,"down_proj")][0] 
                                                                            if ((i,"down_proj") in self.corr_dict.keys())
                                                                            else None), 
                                                                    'b': (self.corr_dict[(i,"down_proj")][1] 
                                                                            if ((i,"down_proj") in self.corr_dict.keys())
                                                                            else None),
                                                                    'targ': targ_dict[(i,"down_proj")],
                                                                    'b_lh': self.b_dict[(i,"down_proj")]})
                elif i == len(self.layers)-1:
                    self.layers[i] = TransformerBlockLast_fused(config, linear_class, linear_kwargs, 
                                                            intra_class_dict={'gate_up':(selectorbg_gu_linreg if ((i,"gate_up_proj") in self.corr_dict.keys() )
                                                                                    else selectorbg_gu_gemv)},
                                                            intra_dict={'gate_up':{'a': (self.corr_dict[(i,"gate_up_proj")][0] 
                                                                                        if ((i,"gate_up_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                                'b': (self.corr_dict[(i,"gate_up_proj")][1] 
                                                                                        if ((i,"gate_up_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                                'targ': targ_dict[(i,"gate_up_proj")],
                                                                                'b_lh': self.b_dict[(i,"gate_up_proj")]}},
                                                            wo_class=selector_linreg if ((i, "o_proj") in self.corr_dict.keys())
                                                                        else selector_gemv,
                                                            wo_dict={'a': (self.corr_dict[(i,"o_proj")][0] 
                                                                            if ((i,"o_proj") in self.corr_dict.keys())
                                                                            else None), 
                                                                    'b': (self.corr_dict[(i,"o_proj")][1] 
                                                                            if ((i,"o_proj") in self.corr_dict.keys())
                                                                            else None),
                                                                    'targ': targ_dict[(i,"o_proj")],
                                                                    'b_lh': self.b_dict[(i,"o_proj")]},
                                                            w2_class=selector_linreg if ((i, "down_proj") in self.corr_dict.keys())
                                                                        else selector_gemv,
                                                            w2_dict={'a': (self.corr_dict[(i,"down_proj")][0] 
                                                                            if ((i,"down_proj") in self.corr_dict.keys())
                                                                            else None), 
                                                                    'b': (self.corr_dict[(i,"down_proj")][1] 
                                                                            if ((i,"down_proj") in self.corr_dict.keys())
                                                                            else None),
                                                                    'targ': targ_dict[(i,"down_proj")],
                                                                    'b_lh': self.b_dict[(i,"down_proj")]})
                    block = self.layers[i]
                
                if i > 0:
                    block.prev_inter_sne = self.layers[i-1].inter_sne
                    block.prev_inter_bsel = self.layers[i-1].inter_bsel
        else:
            for i, block in enumerate(self.layers):
                if i == 0:
                    self.layers[i] = TransformerBlockFirst(config, linear_class, linear_kwargs, 
                                                            intra_class_dict={'gate':(selectorbg_g_linreg if ((i,"gate_proj") in self.corr_dict.keys() )
                                                                                    else selectorbg_g_gemv),
                                                                                'up':(selectorbg_u_linreg if ((i,"up_proj") in self.corr_dict.keys() )
                                                                                    else selectorbg_u_gemv)},
                                                            intra_dict={'gate':{'a': (self.corr_dict[(i,"gate_proj")][0] 
                                                                                        if ((i,"gate_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                                'b': (self.corr_dict[(i,"gate_proj")][1] 
                                                                                        if ((i,"gate_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                                'targ': targ_dict[(i,"gate_proj")],
                                                                                'b_lh': self.b_dict[(i,"gate_proj")]},
                                                                        'up':{'a': (self.corr_dict[(i,"up_proj")][0] 
                                                                                        if ((i,"up_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                                'b': (self.corr_dict[(i,"up_proj")][1] 
                                                                                        if ((i,"up_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                                'targ': targ_dict[(i,"up_proj")],
                                                                                'b_lh': self.b_dict[(i,"up_proj")]}},
                                                            inter_class_dict={'q':(selectorbg_q_linreg if ((i+1,"q_proj") in self.corr_dict.keys() )
                                                                                    else selectorbg_q_gemv),
                                                                                'k':(selectorbg_k_linreg if ((i+1,"k_proj") in self.corr_dict.keys() )
                                                                                    else selectorbg_k_gemv),
                                                                                'v':(selectorbg_v_linreg if ((i+1,"v_proj") in self.corr_dict.keys() )
                                                                                    else selectorbg_v_gemv)},
                                                            inter_dict={'q':{'a': (self.corr_dict[(i+1,"q_proj")][0] 
                                                                                    if ((i+1,"q_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                            'b': (self.corr_dict[(i+1,"q_proj")][1] 
                                                                                    if ((i+1,"q_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                                'targ': targ_dict[(i+1,"q_proj")],
                                                                                'b_lh': self.b_dict[(i+1,"q_proj")]},
                                                                        'k':{'a': (self.corr_dict[(i+1,"k_proj")][0] 
                                                                                    if ((i+1,"k_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                            'b': (self.corr_dict[(i+1,"k_proj")][1] 
                                                                                    if ((i+1,"k_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                                'targ': targ_dict[(i+1,"k_proj")],
                                                                                'b_lh': self.b_dict[(i+1,"k_proj")]},
                                                                        'v':{'a': (self.corr_dict[(i+1,"v_proj")][0] 
                                                                                    if ((i+1,"v_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                            'b': (self.corr_dict[(i+1,"v_proj")][1] 
                                                                                    if ((i+1,"v_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                                'targ': targ_dict[(i+1,"v_proj")],
                                                                                'b_lh': self.b_dict[(i+1,"v_proj")]}},
                                                            trigger_class_dict={'q':(trigger_linreg if ((i,"q_proj") in self.corr_dict.keys() )
                                                                                    else trigger_gemv),
                                                                                'k':(trigger_linreg if ((i,"k_proj") in self.corr_dict.keys() )
                                                                                    else trigger_gemv),
                                                                                'v':(trigger_linreg if ((i,"v_proj") in self.corr_dict.keys() )
                                                                                    else trigger_gemv)},
                                                            trigger_dict={'q':{'a': (self.corr_dict[(i,"q_proj")][0] 
                                                                                    if ((i,"q_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                            'b': (self.corr_dict[(i,"q_proj")][1] 
                                                                                    if ((i,"q_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                            'targ': targ_dict[(i,"q_proj")],
                                                                            'b_lh': self.b_dict[(i,"q_proj")]},
                                                                        'k':{'a': (self.corr_dict[(i,"k_proj")][0] 
                                                                                    if ((i,"k_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                            'b': (self.corr_dict[(i,"k_proj")][1] 
                                                                                    if ((i,"k_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                            'targ': targ_dict[(i,"k_proj")],
                                                                            'b_lh': self.b_dict[(i,"k_proj")]},
                                                                        'v':{'a': (self.corr_dict[(i,"v_proj")][0] 
                                                                                    if ((i,"v_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                            'b': (self.corr_dict[(i,"v_proj")][1] 
                                                                                    if ((i,"v_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                            'targ': targ_dict[(i,"v_proj")],
                                                                            'b_lh': self.b_dict[(i,"v_proj")]}},
                                                            wo_class=selector_linreg if ((i, "o_proj") in self.corr_dict.keys())
                                                                        else selector_gemv,
                                                            wo_dict={'a': (self.corr_dict[(i,"o_proj")][0] 
                                                                            if ((i,"o_proj") in self.corr_dict.keys())
                                                                            else None), 
                                                                    'b': (self.corr_dict[(i,"o_proj")][1] 
                                                                            if ((i,"o_proj") in self.corr_dict.keys())
                                                                            else None),
                                                                    'targ': targ_dict[(i,"o_proj")],
                                                                    'b_lh': self.b_dict[(i,"o_proj")]},
                                                            w2_class=selector_linreg if ((i, "down_proj") in self.corr_dict.keys())
                                                                        else selector_gemv,
                                                            w2_dict={'a': (self.corr_dict[(i,"down_proj")][0] 
                                                                            if ((i,"down_proj") in self.corr_dict.keys())
                                                                            else None), 
                                                                    'b': (self.corr_dict[(i,"down_proj")][1] 
                                                                            if ((i,"down_proj") in self.corr_dict.keys())
                                                                            else None),
                                                                    'targ': targ_dict[(i,"down_proj")],
                                                                    'b_lh': self.b_dict[(i,"down_proj")]})
                elif i == len(self.layers)-1:
                    self.layers[i] = TransformerBlockLast(config, linear_class, linear_kwargs, 
                                                            intra_class_dict={'gate':(selectorbg_g_linreg if ((i,"gate_proj") in self.corr_dict.keys() )
                                                                                    else selectorbg_g_gemv),
                                                                                'up':(selectorbg_u_linreg if ((i,"up_proj") in self.corr_dict.keys() )
                                                                                    else selectorbg_u_gemv)},
                                                            intra_dict={'gate':{'a': (self.corr_dict[(i,"gate_proj")][0] 
                                                                                        if ((i,"gate_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                                'b': (self.corr_dict[(i,"gate_proj")][1] 
                                                                                        if ((i,"gate_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                                'targ': targ_dict[(i,"gate_proj")],
                                                                                'b_lh': self.b_dict[(i,"gate_proj")]},
                                                                        'up':{'a': (self.corr_dict[(i,"up_proj")][0] 
                                                                                        if ((i,"up_proj") in self.corr_dict.keys())
                                                                                    else None), 
                                                                                'b': (self.corr_dict[(i,"up_proj")][1] 
                                                                                        if ((i,"up_proj") in self.corr_dict.keys())
                                                                                    else None),
                                                                                'targ': targ_dict[(i,"up_proj")],
                                                                                'b_lh': self.b_dict[(i,"up_proj")]}},
                                                            wo_class=selector_linreg if ((i, "o_proj") in self.corr_dict.keys())
                                                                        else selector_gemv,
                                                            wo_dict={'a': (self.corr_dict[(i,"o_proj")][0] 
                                                                            if ((i,"o_proj") in self.corr_dict.keys())
                                                                            else None), 
                                                                    'b': (self.corr_dict[(i,"o_proj")][1] 
                                                                            if ((i,"o_proj") in self.corr_dict.keys())
                                                                            else None),
                                                                    'targ': targ_dict[(i,"o_proj")],
                                                                    'b_lh': self.b_dict[(i,"o_proj")]},
                                                            w2_class=selector_linreg if ((i, "down_proj") in self.corr_dict.keys())
                                                                        else selector_gemv,
                                                            w2_dict={'a': (self.corr_dict[(i,"down_proj")][0] 
                                                                            if ((i,"down_proj") in self.corr_dict.keys())
                                                                            else None), 
                                                                    'b': (self.corr_dict[(i,"down_proj")][1] 
                                                                            if ((i,"down_proj") in self.corr_dict.keys())
                                                                            else None),
                                                                    'targ': targ_dict[(i,"down_proj")],
                                                                    'b_lh': self.b_dict[(i,"down_proj")]})
                    block = self.layers[i]
                
                if i > 0:
                    block.prev_inter_sne = self.layers[i-1].inter_sne
                    block.prev_inter_bsel1 = self.layers[i-1].inter_bsel1
                    block.prev_inter_bsel2 = self.layers[i-1].inter_bsel2
                    block.prev_inter_bsel3 = self.layers[i-1].inter_bsel3


        self.freqs_cis: Optional[Tensor] = None
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length):
        """(Re)build the per-layer KV caches, ``freqs_cis`` and the causal mask.

        Returns immediately when the recorded ``max_seq_length`` and
        ``max_batch_size`` already cover the requested sizes. Otherwise the
        requested sequence length is passed through ``find_multiple(..., 8)``
        and everything is allocated again for the new sizes.
        """
        already_sized = (
            self.max_seq_length >= max_seq_length
            and self.max_batch_size >= max_batch_size
        )
        if already_sized:
            return

        cfg = self.config
        per_head_dim = cfg.dim // cfg.n_head
        padded_seq_len = find_multiple(max_seq_length, 8)
        self.max_seq_length = padded_seq_len
        self.max_batch_size = max_batch_size

        # Start from the output projection's weight dtype; for quantized
        # layers the dtype is encoded in the scales tensors instead.
        out_proj = self.output
        cache_dtype = out_proj.weight.dtype
        for quant_attr in ("scales", "scales_and_zeros"):
            if hasattr(out_proj, quant_attr):
                cache_dtype = getattr(out_proj, quant_attr).dtype
                break

        for block in self.layers:
            block.attention.kv_cache = KVCache(
                max_batch_size, padded_seq_len, cfg.n_local_heads, per_head_dim, cache_dtype
            )

        self.freqs_cis = precompute_freqs_cis(
            cfg.block_size, per_head_dim, cfg.rope_base, cache_dtype, cfg.rope_scaling
        )
        all_ones = torch.ones(padded_seq_len, padded_seq_len, dtype=torch.bool)
        self.causal_mask = torch.tril(all_ones)

    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        """Embed ``idx``, apply every layer in order, then the final norm and output head.

        ``input_pos`` indexes both ``self.causal_mask`` and ``self.freqs_cis``;
        the resulting slices are handed unchanged to each layer.
        ``setup_caches`` must have populated ``freqs_cis`` beforehand.
        """
        assert self.freqs_cis is not None, "Caches must be initialized first"
        attn_mask = self.causal_mask[None, None, input_pos]
        rope_freqs = self.freqs_cis[input_pos]
        hidden = self.tok_embeddings(idx)

        # fakeTrigger(self.trigger_sne)

        for layer in self.layers:
            hidden = layer(hidden, input_pos, rope_freqs, attn_mask)
        return self.output(self.norm(hidden))

    def load_dec_data(self, dec_data):
        """Push the entries of ``dec_data`` into every layer's four linear modules.

        For layer ``idx`` and each of ``attention.wqkv``, ``attention.wo``,
        ``feed_forward.w1w3`` and ``feed_forward.w2`` (in that order), the
        module's ``load_dec_data`` is called with the values stored under
        ``"layers.{idx}.<module path>.q_residual"``, ``".scales"`` and
        ``".thresholds"``.
        """
        module_paths = (
            ("attention", "wqkv"),
            ("attention", "wo"),
            ("feed_forward", "w1w3"),
            ("feed_forward", "w2"),
        )
        for idx, layer in enumerate(self.layers):
            for group, leaf in module_paths:
                key = f"layers.{idx}.{group}.{leaf}."
                # Resolve the module lazily so attribute access happens in the
                # same order as the individual calls it replaces.
                loader = getattr(getattr(layer, group), leaf).load_dec_data
                loader(
                    dec_data[key + "q_residual"],
                    dec_data[key + "scales"],
                    dec_data[key + "thresholds"],
                )

    def create_dec_context(self, num_tb, buffer_size = 1024):
        """Allocate the two selection buffers on CUDA and build ``self.dec_context``.

        ``selected_rows_buffer`` is an int tensor and
        ``selected_activations_buffer`` uses ``self.dtype``; both hold
        ``buffer_size`` elements. The context itself comes from the
        module-level ``create_dec_context`` function (not this method).
        """
        self.selected_rows_buffer = torch.empty(
            buffer_size, dtype=torch.int, device='cuda'
        )
        self.selected_activations_buffer = torch.empty(
            buffer_size, dtype=self.dtype, device='cuda'
        )
        rows = self.selected_rows_buffer
        activations = self.selected_activations_buffer
        self.dec_context = create_dec_context(num_tb, rows, activations)

    def update_dec_context(self, num_tb):
        """Rebuild ``self.dec_context`` for ``num_tb`` using the existing buffers.

        Reuses ``selected_rows_buffer`` and ``selected_activations_buffer``,
        which are created by ``create_dec_context`` on this object.
        """
        rows = self.selected_rows_buffer
        activations = self.selected_activations_buffer
        self.dec_context = create_dec_context(num_tb, rows, activations)

    def set_dec_config(self, k_chunk_list):
        """Call ``create_dec_config`` on the four linear modules of every layer.

        ``k_chunk_list`` must have exactly four entries, ordered as
        [qkv, o, gate/up, down]; entry ``n`` is passed, together with
        ``self.dec_context``, to the n-th module listed below.
        """
        assert len(k_chunk_list) == 4
        module_paths = (
            ("attention", "wqkv"),
            ("attention", "wo"),
            ("feed_forward", "w1w3"),
            ("feed_forward", "w2"),
        )
        for layer in self.layers:
            for slot, (group, leaf) in enumerate(module_paths):
                module = getattr(getattr(layer, group), leaf)
                module.create_dec_config(self.dec_context, k_chunk_list[slot])

    def update_dec_config(self, k_chunk_list):
        """Call ``update_dec_config`` on the four linear modules of every layer.

        ``k_chunk_list`` must have exactly four entries; entry ``n`` is
        passed, together with ``self.dec_context``, to the n-th module in
        the order wqkv, wo, w1w3, w2.
        """
        assert len(k_chunk_list) == 4
        module_paths = (
            ("attention", "wqkv"),
            ("attention", "wo"),
            ("feed_forward", "w1w3"),
            ("feed_forward", "w2"),
        )
        for layer in self.layers:
            for slot, (group, leaf) in enumerate(module_paths):
                module = getattr(getattr(layer, group), leaf)
                module.update_dec_config(self.dec_context, k_chunk_list[slot])

    @classmethod
    def from_name(cls, dtype, name: str, linear_class=nn.Linear, linear_kwargs=None, halve_layers=False) -> "Transformer":
        """Alternate constructor: look up ``ModelArgs`` by ``name`` and build the model.

        All remaining arguments are forwarded to the class constructor
        unchanged.
        """
        model_args = ModelArgs.from_name(name)
        return cls(
            dtype,
            model_args,
            linear_class=linear_class,
            linear_kwargs=linear_kwargs,
            halve_layers=halve_layers,
        )


class selector:
    """Base class for ``selector_linreg`` and ``selector_gemv``; does nothing itself."""

    def __init__(self):
        """Store no state."""

    def compare(self, y):
        """Placeholder hook; the base implementation returns ``None``."""

class selector_linreg(selector):
    """Pick ``b_l`` or ``b_h`` from a linear function of the input's norm.

    ``config`` must provide ``"targ"``, ``"a"`` and ``"b"``.
    """

    def __init__(self, config:dict, sne, b_l, b_h):
        self.targ, self.a, self.b = config["targ"], config["a"], config["b"]
        self.b_l, self.b_h = b_l, b_h
        # Cached gap so compare() only needs one multiply-add.
        self.b_d = b_h - b_l
        self.sne = sne

    def compare(self, y):
        """Return ``b_h`` if ``y.norm() * a + b > targ`` else ``b_l``, as an int tensor.

        ``fakeTrigger(self.sne)`` is invoked after the selection is computed.
        """
        estimate = y.norm() * self.a + self.b
        above_target = estimate > self.targ
        bsel = (above_target * self.b_d + self.b_l).to(torch.int)
        fakeTrigger(self.sne)
        return bsel

class selector_gemv(selector):
    """Pick ``b_l`` or ``b_h`` from the norm of ``y @ jl.T``.

    ``config`` must provide ``"targ"`` and ``"jl"``.
    """

    def __init__(self, config:dict, sne, b_l, b_h):
        self.targ, self.jl = config["targ"], config["jl"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h
        # Cached gap so compare() only needs one multiply-add.
        self.b_d = b_h - b_l

    def compare(self, y):
        """Return ``b_h`` if ``(y @ jl.T).norm() > targ`` else ``b_l``, as an int tensor.

        ``fakeTrigger(self.sne)`` is invoked after the selection is computed.
        """
        projected = y @ self.jl.T
        above_target = projected.norm() > self.targ
        bsel = (above_target * self.b_d + self.b_l).to(torch.int)
        fakeTrigger(self.sne)
        return bsel
    
class selectorbg_q_gemv:
    """Argument holder for ``gemvNormTHq`` (by its name, presumably the q variant).

    ``config`` must provide ``"jl"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.jl, self.bsel, self.targ = config["jl"], config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``gemvNormTHq``; returns ``None``."""
        gemvNormTHq(
            x, self.jl, self.bsel,
            self.b_l, self.b_h,
            self.targ, self.sne,
        )

class selectorbg_k_gemv:
    """Argument holder for ``gemvNormTHk`` (by its name, presumably the k variant).

    ``config`` must provide ``"jl"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.jl, self.bsel, self.targ = config["jl"], config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``gemvNormTHk``; returns ``None``."""
        gemvNormTHk(
            x, self.jl, self.bsel,
            self.b_l, self.b_h,
            self.targ, self.sne,
        )

class selectorbg_v_gemv:
    """Argument holder for ``gemvNormTHv`` (by its name, presumably the v variant).

    ``config`` must provide ``"jl"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.jl, self.bsel, self.targ = config["jl"], config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``gemvNormTHv``; returns ``None``."""
        gemvNormTHv(
            x, self.jl, self.bsel,
            self.b_l, self.b_h,
            self.targ, self.sne,
        )

class selectorbg_g_gemv:
    """Argument holder for ``gemvNormTHg`` (by its name, presumably the gate variant).

    ``config`` must provide ``"jl"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.jl, self.bsel, self.targ = config["jl"], config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``gemvNormTHg``; returns ``None``."""
        gemvNormTHg(
            x, self.jl, self.bsel,
            self.b_l, self.b_h,
            self.targ, self.sne,
        )

class selectorbg_u_gemv:
    """Argument holder for ``gemvNormTHu`` (by its name, presumably the up variant).

    ``config`` must provide ``"jl"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.jl, self.bsel, self.targ = config["jl"], config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``gemvNormTHu``; returns ``None``."""
        gemvNormTHu(
            x, self.jl, self.bsel,
            self.b_l, self.b_h,
            self.targ, self.sne,
        )

class selectorbg_qkv_gemv:
    """Argument holder for ``gemvNormTHqkv`` (by its name, presumably the fused qkv variant).

    ``config`` must provide ``"jl"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.jl, self.bsel, self.targ = config["jl"], config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``gemvNormTHqkv``; returns ``None``."""
        gemvNormTHqkv(
            x, self.jl, self.bsel,
            self.b_l, self.b_h,
            self.targ, self.sne,
        )

class selectorbg_gu_gemv:
    """Argument holder for ``gemvNormTHgu`` (by its name, presumably the fused gate/up variant).

    ``config`` must provide ``"jl"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.jl, self.bsel, self.targ = config["jl"], config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``gemvNormTHgu``; returns ``None``."""
        gemvNormTHgu(
            x, self.jl, self.bsel,
            self.b_l, self.b_h,
            self.targ, self.sne,
        )

class selectorbg_q_linreg:
    """Argument holder for ``normTHq`` (by its name, presumably the q variant).

    ``config`` must provide ``"a"``, ``"b"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.a, self.b = config["a"], config["b"]
        self.bsel, self.targ = config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``normTHq``; returns ``None``."""
        normTHq(
            x, self.a, self.b, self.targ,
            self.bsel, self.b_l, self.b_h,
            self.sne,
        )

class selectorbg_k_linreg:
    """Argument holder for ``normTHk`` (by its name, presumably the k variant).

    ``config`` must provide ``"a"``, ``"b"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.a, self.b = config["a"], config["b"]
        self.bsel, self.targ = config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``normTHk``; returns ``None``."""
        normTHk(
            x, self.a, self.b, self.targ,
            self.bsel, self.b_l, self.b_h,
            self.sne,
        )

class selectorbg_v_linreg:
    """Argument holder for ``normTHv`` (by its name, presumably the v variant).

    ``config`` must provide ``"a"``, ``"b"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.a, self.b = config["a"], config["b"]
        self.bsel, self.targ = config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``normTHv``; returns ``None``."""
        normTHv(
            x, self.a, self.b, self.targ,
            self.bsel, self.b_l, self.b_h,
            self.sne,
        )

class selectorbg_g_linreg:
    """Argument holder for ``normTHg`` (by its name, presumably the gate variant).

    ``config`` must provide ``"a"``, ``"b"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.a, self.b = config["a"], config["b"]
        self.bsel, self.targ = config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``normTHg``; returns ``None``."""
        normTHg(
            x, self.a, self.b, self.targ,
            self.bsel, self.b_l, self.b_h,
            self.sne,
        )

class selectorbg_u_linreg:
    """Argument holder for ``normTHu`` (by its name, presumably the up variant).

    ``config`` must provide ``"a"``, ``"b"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.a, self.b = config["a"], config["b"]
        self.bsel, self.targ = config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``normTHu``; returns ``None``."""
        normTHu(
            x, self.a, self.b, self.targ,
            self.bsel, self.b_l, self.b_h,
            self.sne,
        )

class selectorbg_qkv_linreg:
    """Argument holder for ``normTHqkv`` (by its name, presumably the fused qkv variant).

    ``config`` must provide ``"a"``, ``"b"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.a, self.b = config["a"], config["b"]
        self.bsel, self.targ = config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``normTHqkv``; returns ``None``."""
        normTHqkv(
            x, self.a, self.b, self.targ,
            self.bsel, self.b_l, self.b_h,
            self.sne,
        )

class selectorbg_gu_linreg:
    """Argument holder for ``normTHgu`` (by its name, presumably the fused gate/up variant).

    ``config`` must provide ``"a"``, ``"b"``, ``"bsel"`` and ``"targ"``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.a, self.b = config["a"], config["b"]
        self.bsel, self.targ = config["bsel"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Forward ``x`` and the stored fields to ``normTHgu``; returns ``None``."""
        normTHgu(
            x, self.a, self.b, self.targ,
            self.bsel, self.b_l, self.b_h,
            self.sne,
        )

class trigger_gemv:
    """Pick ``b_l`` or ``b_h`` from the norm of ``x @ jl.T``.

    ``config`` must provide ``"jl"`` and ``"targ"``. ``sne`` is stored but
    not used by ``compare``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.jl, self.targ = config["jl"], config["targ"]
        self.sne = sne
        self.b_l, self.b_h = b_l, b_h

    def compare(self, x):
        """Return ``b_h`` if ``(x @ jl.T).norm() > targ`` else ``b_l``, as an int tensor."""
        above_target = (x @ self.jl.T).norm() > self.targ
        gap = self.b_h - self.b_l
        return (above_target * gap + self.b_l).to(torch.int)

class trigger_linreg:
    """Trigger that thresholds a linear function of the input norm.

    ``config`` must provide ``"targ"`` (the threshold) and the linear
    coefficients ``"a"`` and ``"b"``. ``sne`` is stored but not used by
    ``compare``.
    """

    def __init__(self, config, sne, b_l, b_h):
        self.targ, self.a, self.b = config["targ"], config["a"], config["b"]
        self.sne, self.b_l, self.b_h = sne, b_l, b_h

    def compare(self, x):
        """Return ``b_h`` if ``x.norm() * a + b > targ`` else ``b_l``.

        The choice is computed arithmetically (boolean * span + low) and the
        result is cast to ``torch.int``.
        """
        estimate = x.norm() * self.a + self.b
        above_target = estimate > self.targ
        span = self.b_h - self.b_l
        return (above_target * span + self.b_l).to(torch.int)




class TransformerBlockFirst(nn.Module):
    """Transformer block that computes its own attention triggers.

    Besides the attention / feed-forward pair, the constructor builds three
    groups of helper objects; each group shares one handle returned by
    ``create_streamNevent_full()``:

    * intra (``gate``, ``up``): selectors run on the second output of
      ``input_layernorm``; their ``bsel`` buffers (``intra_bsel1/2``) are
      passed to ``feed_forward`` in the same forward pass.
    * inter (``q``, ``k``, ``v``): selectors run on the second output of
      ``post_attention_layernorm``; their ``bsel`` buffers
      (``inter_bsel1/2/3``) are not read inside this class
      (NOTE(review): presumably consumed by the following block -- confirm
      against the model wiring).
    * trigger (``q``, ``k``, ``v``): the values returned by ``compare`` are
      passed straight to this block's own attention call.

    Args:
        config: model hyper-parameters; ``config.dim`` sizes the ``jl``
            buffers and ``config.model_name`` picks the norm classes.
        linear_class, linear_kwargs: forwarded to ``Attention`` and
            ``FeedForward``.
        intra_class_dict, inter_class_dict, trigger_class_dict: map a
            projection name to the selector / trigger class to instantiate.
        intra_dict, inter_dict, trigger_dict: per-projection settings. Each
            entry must contain ``'targ'`` and ``'b_lh'`` (indexed ``[0]`` and
            ``[1]``), plus ``'a'`` and ``'b'`` when the matching class is the
            ``*_linreg`` variant.
        wo_class, wo_dict: forwarded to ``Attention``.
        w2_class, w2_dict: forwarded to ``FeedForward``.

    Notes:
        The ``{}`` defaults of the intra/inter/trigger dictionaries are
        placeholders only: they are indexed unconditionally, so callers must
        supply them.
        Layer norms are only created when ``config.model_name`` contains
        "llama", "mistral", "phi" or "qwen"; for any other name no norm
        attribute is assigned here.
    """

    def __init__(self, config: ModelArgs, linear_class=nn.Linear, linear_kwargs=None,
                intra_class_dict={}, intra_dict={},
                inter_class_dict={}, inter_dict={},
                trigger_class_dict={}, trigger_dict={},
                wo_class=selector_gemv, wo_dict={},
                w2_class=selector_gemv, w2_dict={},
                ) -> None:
        super().__init__()
        self.attention = Attention(config, linear_class, linear_kwargs, wo_class, wo_dict)
        self.feed_forward = FeedForward(config, linear_class, linear_kwargs, w2_class, w2_dict)

        # == My calculations from first residual to FFN ==
        self.intra_sne = create_streamNevent_full()

        # Setup for gate
        gate_config = {}
        # bsel is a one-element int buffer shared between this module and the selector.
        self.intra_bsel1 = torch.empty((1,), dtype=torch.int, device="cuda")
        self.intra_targ1 = intra_dict['gate']['targ']
        gate_config['bsel'] = self.intra_bsel1
        gate_config['targ'] = self.intra_targ1
        if intra_class_dict['gate'] == selectorbg_g_gemv:
            # Allocated uninitialized (torch.empty); presumably filled by the
            # weight-loading code -- confirm.
            self.intra_jl1 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            gate_config['jl'] = self.intra_jl1
        elif intra_class_dict['gate'] == selectorbg_g_linreg:
            self.intra_a1 = intra_dict['gate']['a']
            self.intra_b1 = intra_dict['gate']['b']
            gate_config['a'] = self.intra_a1
            gate_config['b'] = self.intra_b1

        # Instantiate gate selector
        self.gate_selector = intra_class_dict['gate'](gate_config, self.intra_sne, intra_dict['gate']['b_lh'][0], intra_dict['gate']['b_lh'][1])

        # Setup for up
        up_config = {}
        self.intra_bsel2 = torch.empty((1,), dtype=torch.int, device="cuda")
        self.intra_targ2 = intra_dict['up']['targ']
        up_config['bsel'] = self.intra_bsel2
        up_config['targ'] = self.intra_targ2
        if intra_class_dict['up'] == selectorbg_u_gemv:
            self.intra_jl2 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            up_config['jl'] = self.intra_jl2
        elif intra_class_dict['up'] == selectorbg_u_linreg:
            self.intra_a2 = intra_dict['up']['a']
            self.intra_b2 = intra_dict['up']['b']
            up_config['a'] = self.intra_a2
            up_config['b'] = self.intra_b2

        # Instantiate up selector
        self.up_selector = intra_class_dict['up'](up_config, self.intra_sne, intra_dict['up']['b_lh'][0], intra_dict['up']['b_lh'][1])
        # ================================================


        # == Calculations for next layer's attn module ==
        self.inter_sne = create_streamNevent_full()

        # Setup for q
        q_config = {}
        self.inter_bsel1 = torch.empty((1,), dtype=torch.int, device="cuda")
        self.inter_targ1 = inter_dict['q']['targ']
        q_config['bsel'] = self.inter_bsel1
        q_config['targ'] = self.inter_targ1
        if inter_class_dict['q'] == selectorbg_q_gemv:
            self.inter_jl1 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            q_config['jl'] = self.inter_jl1
        elif inter_class_dict['q'] == selectorbg_q_linreg:
            self.inter_a1 = inter_dict['q']['a']
            self.inter_b1 = inter_dict['q']['b']
            q_config['a'] = self.inter_a1
            q_config['b'] = self.inter_b1

        # Instantiate q selector
        self.q_selector = inter_class_dict['q'](q_config, self.inter_sne, inter_dict['q']['b_lh'][0], inter_dict['q']['b_lh'][1])

        # Setup for k
        k_config = {}
        self.inter_bsel2 = torch.empty((1,), dtype=torch.int, device="cuda")
        self.inter_targ2 = inter_dict['k']['targ']
        k_config['bsel'] = self.inter_bsel2
        k_config['targ'] = self.inter_targ2
        if inter_class_dict['k'] == selectorbg_k_gemv:
            self.inter_jl2 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            k_config['jl'] = self.inter_jl2
        elif inter_class_dict['k'] == selectorbg_k_linreg:
            self.inter_a2 = inter_dict['k']['a']
            self.inter_b2 = inter_dict['k']['b']
            k_config['a'] = self.inter_a2
            k_config['b'] = self.inter_b2

        # Instantiate k selector
        self.k_selector = inter_class_dict['k'](k_config, self.inter_sne, inter_dict['k']['b_lh'][0], inter_dict['k']['b_lh'][1])

        # Setup for v
        v_config = {}
        self.inter_bsel3 = torch.empty((1,), dtype=torch.int, device="cuda")
        self.inter_targ3 = inter_dict['v']['targ']
        v_config['bsel'] = self.inter_bsel3
        v_config['targ'] = self.inter_targ3
        if inter_class_dict['v'] == selectorbg_v_gemv:
            self.inter_jl3 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            v_config['jl'] = self.inter_jl3
        elif inter_class_dict['v'] == selectorbg_v_linreg:
            self.inter_a3 = inter_dict['v']['a']
            self.inter_b3 = inter_dict['v']['b']
            v_config['a'] = self.inter_a3
            v_config['b'] = self.inter_b3

        # Instantiate v selector
        self.v_selector = inter_class_dict['v'](v_config, self.inter_sne, inter_dict['v']['b_lh'][0], inter_dict['v']['b_lh'][1])

        # ================================================

        # == Triggers ==
        # Unlike the selectors above, triggers get no 'bsel' buffer: their
        # compare() returns the selection directly.
        self.trigger_sne = create_streamNevent_full()

        # Setup for trigger q
        q_config = {}
        self.trigger_targ1 = trigger_dict['q']['targ']
        q_config['targ'] = self.trigger_targ1
        if trigger_class_dict['q'] == trigger_gemv:
            self.trigger_jl1 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            q_config['jl'] = self.trigger_jl1
        elif trigger_class_dict['q'] == trigger_linreg:
            self.trigger_a1 = trigger_dict['q']['a']
            self.trigger_b1 = trigger_dict['q']['b']
            q_config['a'] = self.trigger_a1
            q_config['b'] = self.trigger_b1

        # Instantiate trigger q
        self.trigger_q = trigger_class_dict['q'](q_config, self.trigger_sne, trigger_dict['q']['b_lh'][0], trigger_dict['q']['b_lh'][1])

        # Setup for trigger k
        k_config = {}
        self.trigger_targ2 = trigger_dict['k']['targ']
        k_config['targ'] = self.trigger_targ2
        if trigger_class_dict['k'] == trigger_gemv:
            self.trigger_jl2 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            k_config['jl'] = self.trigger_jl2
        elif trigger_class_dict['k'] == trigger_linreg:
            self.trigger_a2 = trigger_dict['k']['a']
            self.trigger_b2 = trigger_dict['k']['b']
            k_config['a'] = self.trigger_a2
            k_config['b'] = self.trigger_b2

        # Instantiate trigger k
        self.trigger_k = trigger_class_dict['k'](k_config, self.trigger_sne, trigger_dict['k']['b_lh'][0], trigger_dict['k']['b_lh'][1])

        # Setup for trigger v
        v_config = {}
        self.trigger_targ3 = trigger_dict['v']['targ']
        v_config['targ'] = self.trigger_targ3
        if trigger_class_dict['v'] == trigger_gemv:
            self.trigger_jl3 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            v_config['jl'] = self.trigger_jl3
        elif trigger_class_dict['v'] == trigger_linreg:
            self.trigger_a3 = trigger_dict['v']['a']
            self.trigger_b3 = trigger_dict['v']['b']
            v_config['a'] = self.trigger_a3
            v_config['b'] = self.trigger_b3

        # Instantiate trigger v
        self.trigger_v = trigger_class_dict['v'](v_config, self.trigger_sne, trigger_dict['v']['b_lh'][0], trigger_dict['v']['b_lh'][1])

        # =======================

        # Norm classes are chosen by substring match on the model name; the
        # pre/post feed-forward norms are disabled (None) for every family
        # handled here.
        if "llama" in config.model_name.lower() or "mistral" in config.model_name.lower():
            self.input_layernorm = RMSNorm_dual(config.dim, config.norm_eps)
            self.post_attention_layernorm = RMSNorm_dual(config.dim, config.norm_eps)
            self.pre_feedforward_layernorm = None
            self.post_feedforward_layernorm = None
        elif "phi" in config.model_name.lower() or "qwen" in config.model_name.lower():
            self.input_layernorm = Phi3RMSNorm_dual(config.dim, config.norm_eps)
            self.post_attention_layernorm = Phi3RMSNorm_dual(config.dim, config.norm_eps)
            self.pre_feedforward_layernorm = None
            self.post_feedforward_layernorm = None

    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        """Run attention and feed-forward with residual connections.

        Both dual layer norms return an indexable pair: element ``[0]`` feeds
        the main sub-layer (attention / feed-forward) and element ``[1]``
        feeds the triggers and selectors.

        Returns:
            ``h + feed_forward(...)`` where ``h = x + attention(...)``.
        """
        # Disabled legacy calls, kept for reference:
        # gemvNormTH3Full(x, self.trigger_jl1, self.trigger_jl2, self.trigger_jl3,
        #                 self.trigger_res1, self.trigger_res2, self.trigger_res3,
        #                 self.trigger_bsel1, self.trigger_bsel2, self.trigger_bsel3,
        #                 self.trigger_targ1, self.trigger_targ2, self.trigger_targ3,
        #                 self.trigger_sne)

        # gemvNormTH(x, self.jl, self.intra_bsel, self.targ, self.intra_sne)
        lnx = self.input_layernorm(x)

        # Trigger selections for this block's own q/k/v projections.
        trigger_bsel1 = self.trigger_q.compare(lnx[1])
        trigger_bsel2 = self.trigger_k.compare(lnx[1])
        trigger_bsel3 = self.trigger_v.compare(lnx[1])
        fakeTrigger(self.trigger_sne)

        # Disabled legacy call, kept for reference:
        # gemvNormTH2(lnx[1], self.intra_jl1, self.intra_jl2,
        #             self.intra_bsel1, self.intra_bsel2,
        #             self.intra_targ1, self.intra_targ2, self.intra_sne)
        # Return values are ignored; intra_bsel1/2 are what feed_forward
        # receives below.
        self.gate_selector.compare(lnx[1])
        self.up_selector.compare(lnx[1])

        h = x + self.attention(lnx[0],
                                freqs_cis, mask,
                                trigger_bsel1, trigger_bsel2, trigger_bsel3,
                                self.trigger_sne,
                                input_pos
                                )

        if self.pre_feedforward_layernorm is not None:
            h = self.pre_feedforward_layernorm(h)

        ffn_in = self.post_attention_layernorm(h)
        # Disabled legacy call, kept for reference:
        # gemvNormTH3(ffn_in[1], self.inter_jl1, self.inter_jl2, self.inter_jl3,
        #             self.inter_bsel1, self.inter_bsel2, self.inter_bsel3,
        #             self.inter_targ1, self.inter_targ2, self.inter_targ3, self.inter_sne)
        # Return values are ignored; inter_bsel1/2/3 are not read in this class.
        self.q_selector.compare(ffn_in[1])
        self.k_selector.compare(ffn_in[1])
        self.v_selector.compare(ffn_in[1])
        out = self.feed_forward(ffn_in[0], self.intra_bsel1, self.intra_bsel2, self.intra_sne)

        if self.post_feedforward_layernorm is not None:
            out = self.post_feedforward_layernorm(out)

        out = h + out

        return out

class TransformerBlock(nn.Module):
    """Transformer block whose attention selections come from ``prev_inter_*``.

    Besides the attention / feed-forward pair, the constructor builds two
    groups of selectors; each group shares one handle returned by
    ``create_streamNevent_full()``:

    * intra (``gate``, ``up``): run on the second output of
      ``input_layernorm``; their ``bsel`` buffers (``intra_bsel1/2``) are
      passed to ``feed_forward`` in the same forward pass.
    * inter (``q``, ``k``, ``v``): run on the second output of
      ``post_attention_layernorm``; their ``bsel`` buffers
      (``inter_bsel1/2/3``) are not read inside this class
      (NOTE(review): presumably consumed by the following block -- confirm
      against the model wiring).

    ``prev_inter_sne`` and ``prev_inter_bsel1/2/3`` start as None and are
    passed unchanged to attention in ``forward``.
    NOTE(review): presumably they are assigned from the preceding block's
    ``inter_*`` attributes by the code that assembles the model -- confirm.

    Args:
        config: model hyper-parameters; ``config.dim`` sizes the ``jl``
            buffers and ``config.model_name`` picks the norm classes.
        linear_class, linear_kwargs: forwarded to ``Attention`` and
            ``FeedForward``.
        intra_class_dict, inter_class_dict: map a projection name to the
            selector class to instantiate.
        intra_dict, inter_dict: per-projection settings. Each entry must
            contain ``'targ'`` and ``'b_lh'`` (indexed ``[0]`` and ``[1]``),
            plus ``'a'`` and ``'b'`` when the matching class is the
            ``*_linreg`` variant.
        wo_class, wo_dict: forwarded to ``Attention``.
        w2_class, w2_dict: forwarded to ``FeedForward``.

    Notes:
        The ``{}`` defaults of the intra/inter dictionaries are placeholders
        only: they are indexed unconditionally, so callers must supply them.
        Layer norms are only created when ``config.model_name`` contains
        "llama", "mistral", "phi" or "qwen"; for any other name no norm
        attribute is assigned here.
    """

    def __init__(self, config: ModelArgs, linear_class=nn.Linear, linear_kwargs=None,
                intra_class_dict={}, intra_dict={},
                inter_class_dict={}, inter_dict={},
                wo_class=selector_gemv, wo_dict={},
                w2_class=selector_gemv, w2_dict={}) -> None:
        super().__init__()
        self.attention = Attention(config, linear_class, linear_kwargs, wo_class, wo_dict)
        self.feed_forward = FeedForward(config, linear_class, linear_kwargs, w2_class, w2_dict)

        # == My calculations from first residual to FFN ==
        self.intra_sne = create_streamNevent_full()

        # Setup for gate
        gate_config = {}
        # bsel is a one-element int buffer shared between this module and the selector.
        self.intra_bsel1 = torch.empty((1,), dtype=torch.int, device="cuda")
        self.intra_targ1 = intra_dict['gate']['targ']
        gate_config['bsel'] = self.intra_bsel1
        gate_config['targ'] = self.intra_targ1
        if intra_class_dict['gate'] == selectorbg_g_gemv:
            # Allocated uninitialized (torch.empty); presumably filled by the
            # weight-loading code -- confirm.
            self.intra_jl1 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            gate_config['jl'] = self.intra_jl1
        elif intra_class_dict['gate'] == selectorbg_g_linreg:
            self.intra_a1 = intra_dict['gate']['a']
            self.intra_b1 = intra_dict['gate']['b']
            gate_config['a'] = self.intra_a1
            gate_config['b'] = self.intra_b1

        # Instantiate gate selector
        self.gate_selector = intra_class_dict['gate'](gate_config, self.intra_sne, intra_dict['gate']['b_lh'][0], intra_dict['gate']['b_lh'][1])

        # Setup for up
        up_config = {}
        self.intra_bsel2 = torch.empty((1,), dtype=torch.int, device="cuda")
        self.intra_targ2 = intra_dict['up']['targ']
        up_config['bsel'] = self.intra_bsel2
        up_config['targ'] = self.intra_targ2
        if intra_class_dict['up'] == selectorbg_u_gemv:
            self.intra_jl2 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            up_config['jl'] = self.intra_jl2
        elif intra_class_dict['up'] == selectorbg_u_linreg:
            self.intra_a2 = intra_dict['up']['a']
            self.intra_b2 = intra_dict['up']['b']
            up_config['a'] = self.intra_a2
            up_config['b'] = self.intra_b2

        # Instantiate up selector
        self.up_selector = intra_class_dict['up'](up_config, self.intra_sne, intra_dict['up']['b_lh'][0], intra_dict['up']['b_lh'][1])
        # ================================================


        # == Calculations for next layer's attn module ==
        self.inter_sne = create_streamNevent_full()

        # Setup for q
        q_config = {}
        self.inter_bsel1 = torch.empty((1,), dtype=torch.int, device="cuda")
        self.inter_targ1 = inter_dict['q']['targ']
        q_config['bsel'] = self.inter_bsel1
        q_config['targ'] = self.inter_targ1
        if inter_class_dict['q'] == selectorbg_q_gemv:
            self.inter_jl1 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            q_config['jl'] = self.inter_jl1
        elif inter_class_dict['q'] == selectorbg_q_linreg:
            self.inter_a1 = inter_dict['q']['a']
            self.inter_b1 = inter_dict['q']['b']
            q_config['a'] = self.inter_a1
            q_config['b'] = self.inter_b1

        # Instantiate q selector
        self.q_selector = inter_class_dict['q'](q_config, self.inter_sne, inter_dict['q']['b_lh'][0], inter_dict['q']['b_lh'][1])

        # Setup for k
        k_config = {}
        self.inter_bsel2 = torch.empty((1,), dtype=torch.int, device="cuda")
        self.inter_targ2 = inter_dict['k']['targ']
        k_config['bsel'] = self.inter_bsel2
        k_config['targ'] = self.inter_targ2
        if inter_class_dict['k'] == selectorbg_k_gemv:
            self.inter_jl2 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            k_config['jl'] = self.inter_jl2
        elif inter_class_dict['k'] == selectorbg_k_linreg:
            self.inter_a2 = inter_dict['k']['a']
            self.inter_b2 = inter_dict['k']['b']
            k_config['a'] = self.inter_a2
            k_config['b'] = self.inter_b2

        # Instantiate k selector
        self.k_selector = inter_class_dict['k'](k_config, self.inter_sne, inter_dict['k']['b_lh'][0], inter_dict['k']['b_lh'][1])

        # Setup for v
        v_config = {}
        self.inter_bsel3 = torch.empty((1,), dtype=torch.int, device="cuda")
        self.inter_targ3 = inter_dict['v']['targ']
        v_config['bsel'] = self.inter_bsel3
        v_config['targ'] = self.inter_targ3
        if inter_class_dict['v'] == selectorbg_v_gemv:
            self.inter_jl3 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            v_config['jl'] = self.inter_jl3
        elif inter_class_dict['v'] == selectorbg_v_linreg:
            self.inter_a3 = inter_dict['v']['a']
            self.inter_b3 = inter_dict['v']['b']
            v_config['a'] = self.inter_a3
            v_config['b'] = self.inter_b3

        # Instantiate v selector
        self.v_selector = inter_class_dict['v'](v_config, self.inter_sne, inter_dict['v']['b_lh'][0], inter_dict['v']['b_lh'][1])

        # ================================================


        # My result from prev layer
        # Left as None here; forward() passes them to attention as-is.
        self.prev_inter_sne = None
        self.prev_inter_bsel1 = None
        self.prev_inter_bsel2 = None
        self.prev_inter_bsel3 = None

        # Norm classes are chosen by substring match on the model name; the
        # pre/post feed-forward norms are disabled (None) for every family
        # handled here.
        if "llama" in config.model_name.lower() or "mistral" in config.model_name.lower():
            self.input_layernorm = RMSNorm_dual(config.dim, config.norm_eps)
            self.post_attention_layernorm = RMSNorm_dual(config.dim, config.norm_eps)
            self.pre_feedforward_layernorm = None
            self.post_feedforward_layernorm = None
        elif "phi" in config.model_name.lower() or "qwen" in config.model_name.lower():
            self.input_layernorm = Phi3RMSNorm_dual(config.dim, config.norm_eps)
            self.post_attention_layernorm = Phi3RMSNorm_dual(config.dim, config.norm_eps)
            self.pre_feedforward_layernorm = None
            self.post_feedforward_layernorm = None

    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        """Run attention and feed-forward with residual connections.

        Both dual layer norms return an indexable pair: element ``[0]`` feeds
        the main sub-layer (attention / feed-forward) and element ``[1]``
        feeds the selectors.

        Returns:
            ``h + feed_forward(...)`` where ``h = x + attention(...)``.
        """
        # Disabled legacy call, kept for reference:
        # gemvNormTH(x, self.jl, self.intra_bsel, self.targ, self.intra_sne)
        lnx = self.input_layernorm(x)
        # Disabled legacy call, kept for reference:
        # gemvNormTH2(lnx[1], self.intra_jl1, self.intra_jl2,
        #             self.intra_bsel1, self.intra_bsel2,
        #             self.intra_targ1, self.intra_targ2, self.intra_sne)
        # Return values are ignored; intra_bsel1/2 are what feed_forward
        # receives below.
        self.gate_selector.compare(lnx[1])
        self.up_selector.compare(lnx[1])
        h = x + self.attention(lnx[0],
                                freqs_cis, mask,
                                self.prev_inter_bsel1, self.prev_inter_bsel2, self.prev_inter_bsel3,
                                self.prev_inter_sne,
                                input_pos
                                )

        if self.pre_feedforward_layernorm is not None:
            h = self.pre_feedforward_layernorm(h)

        ffn_in = self.post_attention_layernorm(h)
        # Disabled legacy call, kept for reference:
        # gemvNormTH3(ffn_in[1], self.inter_jl1, self.inter_jl2, self.inter_jl3,
        #             self.inter_bsel1, self.inter_bsel2, self.inter_bsel3,
        #             self.inter_targ1, self.inter_targ2, self.inter_targ3, self.inter_sne)
        # Return values are ignored; inter_bsel1/2/3 are not read in this class.
        self.q_selector.compare(ffn_in[1])
        self.k_selector.compare(ffn_in[1])
        self.v_selector.compare(ffn_in[1])
        out = self.feed_forward(ffn_in[0], self.intra_bsel1, self.intra_bsel2, self.intra_sne)

        if self.post_feedforward_layernorm is not None:
            out = self.post_feedforward_layernorm(out)

        out = h + out

        return out

class TransformerBlockLast(nn.Module):
    """Transformer block without inter (q/k/v) selectors.

    Only the intra (``gate``, ``up``) selectors are built; they share one
    handle returned by ``create_streamNevent_full()`` and run on the second
    output of ``input_layernorm``. Their ``bsel`` buffers (``intra_bsel1/2``)
    are passed to ``feed_forward`` in the same forward pass.

    ``post_attention_layernorm`` is the single-output norm class here
    (``RMSNorm`` / ``Phi3RMSNorm``), so its result is passed to
    ``feed_forward`` without indexing.

    ``prev_inter_sne`` and ``prev_inter_bsel1/2/3`` start as None and are
    passed unchanged to attention in ``forward``.
    NOTE(review): presumably they are assigned from the preceding block's
    ``inter_*`` attributes by the code that assembles the model -- confirm.

    Args:
        config: model hyper-parameters; ``config.dim`` sizes the ``jl``
            buffers and ``config.model_name`` picks the norm classes.
        linear_class, linear_kwargs: forwarded to ``Attention`` and
            ``FeedForward``.
        intra_class_dict: maps ``'gate'`` / ``'up'`` to the selector class to
            instantiate.
        intra_dict: per-projection settings. Each entry must contain
            ``'targ'`` and ``'b_lh'`` (indexed ``[0]`` and ``[1]``), plus
            ``'a'`` and ``'b'`` when the matching class is the ``*_linreg``
            variant.
        wo_class, wo_dict: forwarded to ``Attention``.
        w2_class, w2_dict: forwarded to ``FeedForward``.

    Notes:
        The ``{}`` defaults of the intra dictionaries are placeholders only:
        they are indexed unconditionally, so callers must supply them.
        Layer norms are only created when ``config.model_name`` contains
        "llama", "mistral", "phi" or "qwen"; for any other name no norm
        attribute is assigned here.
    """

    def __init__(self, config: ModelArgs, linear_class=nn.Linear, linear_kwargs=None,
                intra_class_dict={}, intra_dict={},
                wo_class=selector_gemv, wo_dict={},
                w2_class=selector_gemv, w2_dict={}) -> None:
        super().__init__()
        self.attention = Attention(config, linear_class, linear_kwargs, wo_class, wo_dict)
        self.feed_forward = FeedForward(config, linear_class, linear_kwargs, w2_class, w2_dict)

        # == My calculations from first residual to FFN ==
        self.intra_sne = create_streamNevent_full()

        # Setup for gate
        gate_config = {}
        # bsel is a one-element int buffer shared between this module and the selector.
        self.intra_bsel1 = torch.empty((1,), dtype=torch.int, device="cuda")
        self.intra_targ1 = intra_dict['gate']['targ']
        gate_config['bsel'] = self.intra_bsel1
        gate_config['targ'] = self.intra_targ1
        if intra_class_dict['gate'] == selectorbg_g_gemv:
            # Allocated uninitialized (torch.empty); presumably filled by the
            # weight-loading code -- confirm.
            self.intra_jl1 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            gate_config['jl'] = self.intra_jl1
        elif intra_class_dict['gate'] == selectorbg_g_linreg:
            self.intra_a1 = intra_dict['gate']['a']
            self.intra_b1 = intra_dict['gate']['b']
            gate_config['a'] = self.intra_a1
            gate_config['b'] = self.intra_b1

        # Instantiate gate selector
        self.gate_selector = intra_class_dict['gate'](gate_config, self.intra_sne, intra_dict['gate']['b_lh'][0], intra_dict['gate']['b_lh'][1])

        # Setup for up
        up_config = {}
        self.intra_bsel2 = torch.empty((1,), dtype=torch.int, device="cuda")
        self.intra_targ2 = intra_dict['up']['targ']
        up_config['bsel'] = self.intra_bsel2
        up_config['targ'] = self.intra_targ2
        if intra_class_dict['up'] == selectorbg_u_gemv:
            self.intra_jl2 = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            up_config['jl'] = self.intra_jl2
        elif intra_class_dict['up'] == selectorbg_u_linreg:
            self.intra_a2 = intra_dict['up']['a']
            self.intra_b2 = intra_dict['up']['b']
            up_config['a'] = self.intra_a2
            up_config['b'] = self.intra_b2

        # Instantiate up selector
        self.up_selector = intra_class_dict['up'](up_config, self.intra_sne, intra_dict['up']['b_lh'][0], intra_dict['up']['b_lh'][1])
        # ================================================


        # My result from prev layer
        # Left as None here; forward() passes them to attention as-is.
        self.prev_inter_sne = None
        self.prev_inter_bsel1 = None
        self.prev_inter_bsel2 = None
        self.prev_inter_bsel3 = None

        # Norm classes are chosen by substring match on the model name. Only
        # the input norm is the dual variant; the pre/post feed-forward norms
        # are disabled (None) for every family handled here.
        if "llama" in config.model_name.lower() or "mistral" in config.model_name.lower():
            self.input_layernorm = RMSNorm_dual(config.dim, config.norm_eps)
            self.post_attention_layernorm = RMSNorm(config.dim, config.norm_eps)
            self.pre_feedforward_layernorm = None
            self.post_feedforward_layernorm = None
        elif "phi" in config.model_name.lower() or "qwen" in config.model_name.lower():
            self.input_layernorm = Phi3RMSNorm_dual(config.dim, config.norm_eps)
            self.post_attention_layernorm = Phi3RMSNorm(config.dim, config.norm_eps)
            self.pre_feedforward_layernorm = None
            self.post_feedforward_layernorm = None

    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        """Run attention and feed-forward with residual connections.

        ``input_layernorm`` returns an indexable pair: element ``[0]`` feeds
        attention and element ``[1]`` feeds the gate/up selectors.

        Returns:
            ``h + feed_forward(...)`` where ``h = x + attention(...)``.
        """
        # Disabled legacy call, kept for reference:
        # gemvNormTH(x, self.jl, self.intra_bsel, self.targ, self.intra_sne)

        lnx = self.input_layernorm(x)
        # Disabled legacy call, kept for reference:
        # gemvNormTH2(lnx[1], self.intra_jl1, self.intra_jl2,
        #             self.intra_bsel1, self.intra_bsel2,
        #             self.intra_targ1, self.intra_targ2, self.intra_sne)
        # Return values are ignored; intra_bsel1/2 are what feed_forward
        # receives below.
        self.gate_selector.compare(lnx[1])
        self.up_selector.compare(lnx[1])
        h = x + self.attention(lnx[0],
                                freqs_cis, mask,
                                self.prev_inter_bsel1, self.prev_inter_bsel2, self.prev_inter_bsel3,
                                self.prev_inter_sne,
                                input_pos
                                )

        if self.pre_feedforward_layernorm is not None:
            h = self.pre_feedforward_layernorm(h)

        ffn_in = self.post_attention_layernorm(h)
        out = self.feed_forward(ffn_in, self.intra_bsel1, self.intra_bsel2, self.intra_sne)

        if self.post_feedforward_layernorm is not None:
            out = self.post_feedforward_layernorm(out)

        out = h + out

        return out


class TransformerBlockFirst_fused(nn.Module):
    """Fused-projection transformer block that computes its own attention trigger.

    Uses ``Attention_fused`` / ``FeedForward_fused`` and one selector per
    fused projection instead of one per individual projection. The
    constructor builds three helpers, each with its own handle returned by
    ``create_streamNevent_full()``:

    * intra (``gate_up``): selector run on the second output of
      ``input_layernorm``; its ``bsel`` buffer (``intra_bsel``) is passed to
      ``feed_forward`` in the same forward pass.
    * inter (``qkv``): selector run on the second output of
      ``post_attention_layernorm``; its ``bsel`` buffer (``inter_bsel``) is
      not read inside this class (NOTE(review): presumably consumed by the
      following block -- confirm against the model wiring).
    * trigger (``qkv``): the value returned by ``compare`` is passed straight
      to this block's own attention call.

    Args:
        config: model hyper-parameters; ``config.dim`` sizes the ``jl``
            buffers and ``config.model_name`` picks the norm classes.
        linear_class, linear_kwargs: forwarded to ``Attention_fused`` and
            ``FeedForward_fused``.
        intra_class_dict, inter_class_dict, trigger_class_dict: map
            ``'gate_up'`` / ``'qkv'`` to the selector or trigger class to
            instantiate.
        intra_dict, inter_dict, trigger_dict: per-projection settings. Each
            entry must contain ``'targ'`` and ``'b_lh'`` (indexed ``[0]`` and
            ``[1]``), plus ``'a'`` and ``'b'`` when the matching class is the
            ``*_linreg`` variant.
        wo_class, wo_dict: forwarded to ``Attention_fused``.
        w2_class, w2_dict: forwarded to ``FeedForward_fused``.

    Notes:
        The ``{}`` defaults of the intra/inter/trigger dictionaries are
        placeholders only: they are indexed unconditionally, so callers must
        supply them.
        Layer norms are only created when ``config.model_name`` contains
        "llama", "mistral" or "phi"; for any other name no norm attribute is
        assigned here.
    """

    def __init__(self, config: ModelArgs, linear_class=nn.Linear, linear_kwargs=None,
                intra_class_dict={}, intra_dict={},
                inter_class_dict={}, inter_dict={},
                trigger_class_dict={}, trigger_dict={},
                wo_class=selector_gemv, wo_dict={},
                w2_class=selector_gemv, w2_dict={}) -> None:
        super().__init__()
        self.attention = Attention_fused(config, linear_class, linear_kwargs, wo_class, wo_dict)
        self.feed_forward = FeedForward_fused(config, linear_class, linear_kwargs, w2_class, w2_dict)

        # == My calculations from first residual to FFN ==
        self.intra_sne = create_streamNevent_full()

        # Setup for gateup
        gate_up_config = {}
        # bsel is a one-element int buffer shared between this module and the selector.
        self.intra_bsel = torch.empty((1,), dtype=torch.int, device="cuda")
        self.intra_targ = intra_dict['gate_up']['targ']
        gate_up_config['bsel'] = self.intra_bsel
        gate_up_config['targ'] = self.intra_targ
        if intra_class_dict['gate_up'] == selectorbg_gu_gemv:
            # Allocated uninitialized (torch.empty); presumably filled by the
            # weight-loading code -- confirm.
            self.intra_jl = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            gate_up_config['jl'] = self.intra_jl
        elif intra_class_dict['gate_up'] == selectorbg_gu_linreg:
            self.intra_a = intra_dict['gate_up']['a']
            self.intra_b = intra_dict['gate_up']['b']
            gate_up_config['a'] = self.intra_a
            gate_up_config['b'] = self.intra_b

        # Instantiate gateup selector
        self.gate_up_selector = intra_class_dict['gate_up'](gate_up_config, self.intra_sne, intra_dict['gate_up']['b_lh'][0], intra_dict['gate_up']['b_lh'][1])
        # ================================================


        # == Calculations for next layer's attn module ==
        self.inter_sne = create_streamNevent_full()

        # Setup for qkv
        qkv_config = {}
        self.inter_bsel = torch.empty((1,), dtype=torch.int, device="cuda")
        self.inter_targ = inter_dict['qkv']['targ']
        qkv_config['bsel'] = self.inter_bsel
        qkv_config['targ'] = self.inter_targ
        if inter_class_dict['qkv'] == selectorbg_qkv_gemv:
            self.inter_jl = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            qkv_config['jl'] = self.inter_jl
        elif inter_class_dict['qkv'] == selectorbg_qkv_linreg:
            self.inter_a = inter_dict['qkv']['a']
            self.inter_b = inter_dict['qkv']['b']
            qkv_config['a'] = self.inter_a
            qkv_config['b'] = self.inter_b

        # Instantiate qkv selector
        self.qkv_selector = inter_class_dict['qkv'](qkv_config, self.inter_sne, inter_dict['qkv']['b_lh'][0], inter_dict['qkv']['b_lh'][1])
        # ================================================

        # == Triggers ==
        # Unlike the selectors above, the trigger gets no 'bsel' buffer: its
        # compare() returns the selection directly.
        self.trigger_sne = create_streamNevent_full()

        # Setup for trigger qkv
        qkv_config = {}
        self.trigger_targ = trigger_dict['qkv']['targ']
        qkv_config['targ'] = self.trigger_targ
        if trigger_class_dict['qkv'] == trigger_gemv:
            self.trigger_jl = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            qkv_config['jl'] = self.trigger_jl
        elif trigger_class_dict['qkv'] == trigger_linreg:
            self.trigger_a = trigger_dict['qkv']['a']
            self.trigger_b = trigger_dict['qkv']['b']
            qkv_config['a'] = self.trigger_a
            qkv_config['b'] = self.trigger_b

        # Instantiate trigger qkv
        self.trigger_qkv = trigger_class_dict['qkv'](qkv_config, self.trigger_sne, trigger_dict['qkv']['b_lh'][0], trigger_dict['qkv']['b_lh'][1])

        # =======================

        # Norm classes are chosen by substring match on the model name.
        # NOTE(review): the non-fused blocks in this file also map "qwen" to
        # the Phi3 norm classes; this branch only matches "phi". Confirm
        # whether qwen is intentionally unsupported on the fused path.
        if "llama" in config.model_name.lower() or "mistral" in config.model_name.lower():
            self.input_layernorm = RMSNorm_dual(config.dim, config.norm_eps)
            self.post_attention_layernorm = RMSNorm_dual(config.dim, config.norm_eps)
            self.pre_feedforward_layernorm = None
            self.post_feedforward_layernorm = None
        elif "phi" in config.model_name.lower():
            self.input_layernorm = Phi3RMSNorm_dual(config.dim, config.norm_eps)
            self.post_attention_layernorm = Phi3RMSNorm_dual(config.dim, config.norm_eps)
            self.pre_feedforward_layernorm = None
            self.post_feedforward_layernorm = None

    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        """Run fused attention and feed-forward with residual connections.

        Both dual layer norms return an indexable pair: element ``[0]`` feeds
        the main sub-layer (attention / feed-forward) and element ``[1]``
        feeds the trigger and selectors.

        Returns:
            ``h + feed_forward(...)`` where ``h = x + attention(...)``.
        """
        # Disabled legacy calls, kept for reference:
        # gemvNormTH3Full(x, self.trigger_jl1, self.trigger_jl2, self.trigger_jl3,
        #                 self.trigger_res1, self.trigger_res2, self.trigger_res3,
        #                 self.trigger_bsel1, self.trigger_bsel2, self.trigger_bsel3,
        #                 self.trigger_targ1, self.trigger_targ2, self.trigger_targ3,
        #                 self.trigger_sne)

        # gemvNormTH(x, self.jl, self.intra_bsel, self.targ, self.intra_sne)
        lnx = self.input_layernorm(x)

        # Trigger selection for this block's own fused qkv projection.
        trigger_bsel = self.trigger_qkv.compare(lnx[1])
        fakeTrigger(self.trigger_sne)

        # Disabled legacy call, kept for reference:
        # gemvNormTH2(lnx[1], self.intra_jl1, self.intra_jl2,
        #             self.intra_bsel1, self.intra_bsel2,
        #             self.intra_targ1, self.intra_targ2, self.intra_sne)
        # Return value is ignored; intra_bsel is what feed_forward receives below.
        self.gate_up_selector.compare(lnx[1])

        h = x + self.attention(lnx[0],
                                freqs_cis, mask,
                                trigger_bsel,
                                self.trigger_sne,
                                input_pos
                                )

        if self.pre_feedforward_layernorm is not None:
            h = self.pre_feedforward_layernorm(h)

        ffn_in = self.post_attention_layernorm(h)
        # Disabled legacy call, kept for reference:
        # gemvNormTH3(ffn_in[1], self.inter_jl1, self.inter_jl2, self.inter_jl3,
        #             self.inter_bsel1, self.inter_bsel2, self.inter_bsel3,
        #             self.inter_targ1, self.inter_targ2, self.inter_targ3, self.inter_sne)
        # Return value is ignored; inter_bsel is not read in this class.
        self.qkv_selector.compare(ffn_in[1])
        out = self.feed_forward(ffn_in[0], self.intra_bsel, self.intra_sne)

        if self.post_feedforward_layernorm is not None:
            out = self.post_feedforward_layernorm(out)

        out = h + out

        return out

class TransformerBlock_fused(nn.Module):
    """Transformer block built on fused attention / feed-forward modules.

    The block owns two selector objects:

    * ``gate_up_selector`` ("intra"): its ``compare`` is called on the second
      output of ``input_layernorm``; the feed-forward call receives
      ``intra_bsel`` / ``intra_sne``.
    * ``qkv_selector`` ("inter"): its ``compare`` is called on the second
      output of ``post_attention_layernorm``; per the original section comment
      it computes for the next layer's attention module.

    ``prev_inter_bsel`` / ``prev_inter_sne`` start as ``None`` and are what
    ``forward`` hands to this block's attention; nothing in this class assigns
    them, so they are presumably wired up by the owning model (TODO confirm).
    """

    def __init__(self, config: ModelArgs, linear_class=nn.Linear, linear_kwargs=None,
                intra_class_dict={}, intra_dict={}, 
                inter_class_dict={}, inter_dict={},
                wo_class=selector_gemv, wo_dict={},
                w2_class=selector_gemv, w2_dict={}) -> None:
        """Create sub-modules, selector buffers and layer norms.

        Args:
            config: Model hyper-parameters (``dim``, ``norm_eps``, ``model_name`` are read here).
            linear_class: Forwarded to the attention / feed-forward constructors.
            linear_kwargs: Forwarded to the attention / feed-forward constructors.
            intra_class_dict: Must map ``'gate_up'`` to the selector class to instantiate.
            intra_dict: Must map ``'gate_up'`` to a dict with ``'targ'`` and ``'b_lh'``
                (plus ``'a'`` / ``'b'`` for the linreg selector).
            inter_class_dict: Must map ``'qkv'`` to the selector class to instantiate.
            inter_dict: Must map ``'qkv'`` to a dict with ``'targ'`` and ``'b_lh'``
                (plus ``'a'`` / ``'b'`` for the linreg selector).
            wo_class, wo_dict: Forwarded to ``Attention_fused``.
            w2_class, w2_dict: Forwarded to ``FeedForward_fused``.
        """
        super().__init__()
        self.attention = Attention_fused(config, linear_class, linear_kwargs, wo_class, wo_dict)
        self.feed_forward = FeedForward_fused(config, linear_class, linear_kwargs, w2_class, w2_dict)

        # ---- "intra" selector: from the first residual to this block's FFN ----
        self.intra_sne = create_streamNevent_full()
        self.intra_bsel = torch.empty((1,), dtype=torch.int, device="cuda")
        gate_up_params = intra_dict['gate_up']
        self.intra_targ = gate_up_params['targ']
        gate_up_config = {'bsel': self.intra_bsel, 'targ': self.intra_targ}

        gate_up_cls = intra_class_dict['gate_up']
        if gate_up_cls == selectorbg_gu_gemv:
            # Uninitialised (64, dim) fp16 buffer handed to the selector as 'jl'.
            self.intra_jl = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            gate_up_config['jl'] = self.intra_jl
        elif gate_up_cls == selectorbg_gu_linreg:
            self.intra_a = gate_up_params['a']
            self.intra_b = gate_up_params['b']
            gate_up_config['a'] = self.intra_a
            gate_up_config['b'] = self.intra_b

        gate_up_b_lh = gate_up_params['b_lh']
        self.gate_up_selector = gate_up_cls(
            gate_up_config, self.intra_sne, gate_up_b_lh[0], gate_up_b_lh[1]
        )

        # ---- "inter" selector: calculations for the next layer's attention ----
        self.inter_sne = create_streamNevent_full()
        self.inter_bsel = torch.empty((1,), dtype=torch.int, device="cuda")
        qkv_params = inter_dict['qkv']
        self.inter_targ = qkv_params['targ']
        qkv_config = {'bsel': self.inter_bsel, 'targ': self.inter_targ}

        qkv_cls = inter_class_dict['qkv']
        if qkv_cls == selectorbg_qkv_gemv:
            self.inter_jl = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            qkv_config['jl'] = self.inter_jl
        elif qkv_cls == selectorbg_qkv_linreg:
            self.inter_a = qkv_params['a']
            self.inter_b = qkv_params['b']
            qkv_config['a'] = self.inter_a
            qkv_config['b'] = self.inter_b

        qkv_b_lh = qkv_params['b_lh']
        self.qkv_selector = qkv_cls(qkv_config, self.inter_sne, qkv_b_lh[0], qkv_b_lh[1])

        # Selection result / stream handle coming from the previous layer.
        self.prev_inter_sne = None
        self.prev_inter_bsel = None

        # Layer norms depend on the model family; for any other model name
        # no norm attributes are created at all.
        model_name = config.model_name.lower()
        is_llama_like = "llama" in model_name or "mistral" in model_name
        if is_llama_like or "phi" in model_name:
            dual_norm_cls = RMSNorm_dual if is_llama_like else Phi3RMSNorm_dual
            self.input_layernorm = dual_norm_cls(config.dim, config.norm_eps)
            self.post_attention_layernorm = dual_norm_cls(config.dim, config.norm_eps)
            self.pre_feedforward_layernorm = None
            self.post_feedforward_layernorm = None

    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        """Run attention then feed-forward, each followed by a residual add.

        Args:
            x: Block input; also the residual added to the attention output.
            input_pos: Forwarded unchanged to ``self.attention``.
            freqs_cis: Forwarded unchanged to ``self.attention``.
            mask: Forwarded unchanged to ``self.attention``.

        Returns:
            ``h + ffn_out`` where ``h`` is the attention residual stream.
        """
        normed_in = self.input_layernorm(x)

        # Issued before attention runs; the feed-forward call below receives
        # self.intra_bsel / self.intra_sne.
        self.gate_up_selector.compare(normed_in[1])

        attn_out = self.attention(
            normed_in[0],
            freqs_cis,
            mask,
            self.prev_inter_bsel,
            self.prev_inter_sne,
            input_pos,
        )
        h = x + attn_out

        pre_ffn_norm = self.pre_feedforward_layernorm
        if pre_ffn_norm is not None:
            h = pre_ffn_norm(h)

        ffn_in = self.post_attention_layernorm(h)
        # Return value unused here; the selector was built with
        # self.inter_bsel / self.inter_sne in its configuration.
        self.qkv_selector.compare(ffn_in[1])
        ffn_out = self.feed_forward(ffn_in[0], self.intra_bsel, self.intra_sne)

        post_ffn_norm = self.post_feedforward_layernorm
        if post_ffn_norm is not None:
            ffn_out = post_ffn_norm(ffn_out)

        return h + ffn_out

class TransformerBlockLast_fused(nn.Module):
    """Final fused transformer block: has a gate/up selector but no next-layer selector.

    ``prev_inter_bsel`` / ``prev_inter_sne`` start as ``None`` and are what
    ``forward`` hands to this block's attention; nothing in this class assigns
    them, so they are presumably wired up by the owning model (TODO confirm).
    """

    def __init__(self, config: ModelArgs, linear_class=nn.Linear, linear_kwargs=None,
                intra_class_dict={}, intra_dict={}, 
                wo_class=selector_gemv, wo_dict={},
                w2_class=selector_gemv, w2_dict={}) -> None:
        """Create sub-modules, the gate/up selector and the layer norms.

        Args:
            config: Model hyper-parameters (``dim``, ``norm_eps``, ``model_name`` are read here).
            linear_class: Forwarded to the attention / feed-forward constructors.
            linear_kwargs: Forwarded to the attention / feed-forward constructors.
            intra_class_dict: Must map ``'gate_up'`` to the selector class to instantiate.
            intra_dict: Must map ``'gate_up'`` to a dict with ``'targ'`` and ``'b_lh'``
                (plus ``'a'`` / ``'b'`` for the linreg selector).
            wo_class, wo_dict: Forwarded to ``Attention_fused``.
            w2_class, w2_dict: Forwarded to ``FeedForward_fused``.
        """
        super().__init__()
        self.attention = Attention_fused(config, linear_class, linear_kwargs, wo_class, wo_dict)
        self.feed_forward = FeedForward_fused(config, linear_class, linear_kwargs, w2_class, w2_dict)

        # == My calculations from first residual to FFN ==
        self.intra_sne = create_streamNevent_full()

        # Setup for gateup
        gate_up_config = {}
        self.intra_bsel = torch.empty((1,), dtype=torch.int, device="cuda")
        self.intra_targ = intra_dict['gate_up']['targ']
        gate_up_config['bsel'] = self.intra_bsel
        gate_up_config['targ'] = self.intra_targ
        if intra_class_dict['gate_up'] == selectorbg_gu_gemv:
            # Uninitialised (64, dim) fp16 buffer handed to the selector as 'jl'.
            self.intra_jl = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            gate_up_config['jl'] = self.intra_jl
        elif intra_class_dict['gate_up'] == selectorbg_gu_linreg:
            self.intra_a = intra_dict['gate_up']['a']
            self.intra_b = intra_dict['gate_up']['b']
            gate_up_config['a'] = self.intra_a
            gate_up_config['b'] = self.intra_b

        # Instantiate gateup selector
        self.gate_up_selector = intra_class_dict['gate_up'](gate_up_config, self.intra_sne, intra_dict['gate_up']['b_lh'][0], intra_dict['gate_up']['b_lh'][1])
        # ================================================

        # My result from prev layer
        self.prev_inter_sne = None
        self.prev_inter_bsel = None

        if "llama" in config.model_name.lower() or "mistral" in config.model_name.lower():
            self.input_layernorm = RMSNorm_dual(config.dim, config.norm_eps)
            # NOTE: this is the dual norm (returns a pair) whereas the phi
            # branch below uses a single-output norm. The class is left as-is
            # so the module's parameter names do not change; forward() unwraps
            # the pair instead.
            self.post_attention_layernorm = RMSNorm_dual(config.dim, config.norm_eps)
            self.pre_feedforward_layernorm = None
            self.post_feedforward_layernorm = None
        elif "phi" in config.model_name.lower():
            self.input_layernorm = Phi3RMSNorm_dual(config.dim, config.norm_eps)
            self.post_attention_layernorm = Phi3RMSNorm(config.dim, config.norm_eps)
            self.pre_feedforward_layernorm = None
            self.post_feedforward_layernorm = None

    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        """Run attention then feed-forward, each followed by a residual add.

        Args:
            x: Block input; also the residual added to the attention output.
            input_pos: Forwarded unchanged to ``self.attention``.
            freqs_cis: Forwarded unchanged to ``self.attention``.
            mask: Forwarded unchanged to ``self.attention``.

        Returns:
            ``h + ffn_out`` where ``h`` is the attention residual stream.
        """
        lnx = self.input_layernorm(x)

        # Issued before attention runs; the feed-forward call below receives
        # self.intra_bsel / self.intra_sne.
        self.gate_up_selector.compare(lnx[1])

        h = x + self.attention(lnx[0],
                               freqs_cis, mask,
                               self.prev_inter_bsel,
                               self.prev_inter_sne,
                               input_pos
                               )

        if self.pre_feedforward_layernorm is not None:
            h = self.pre_feedforward_layernorm(h)

        ffn_in = self.post_attention_layernorm(h)
        # The llama/mistral configuration installs a dual norm that returns
        # (current-layer output, next-layer output). There is no next-layer
        # selector in the last block, so only the first element is needed;
        # passing the whole tuple to the feed-forward would be a type error.
        if isinstance(ffn_in, tuple):
            ffn_in = ffn_in[0]
        out = self.feed_forward(ffn_in, self.intra_bsel, self.intra_sne)

        if self.post_feedforward_layernorm is not None:
            out = self.post_feedforward_layernorm(out)

        out = h + out

        return out



class Attention(nn.Module):
    """Self-attention with separate q/k/v/o ``APLinearSel`` projections.

    Every projection call receives a selection argument and a stream/event
    handle in addition to the activations. The q/k/v selections are supplied
    by the caller; the selection for ``wo`` is produced here by
    ``self.wo_selector.compare`` on the attention output.
    """

    def __init__(self, config: ModelArgs, linear_class=nn.Linear, linear_kwargs=None,
                 wo_class=selector_linreg, wo_dict={}) -> None:
        """Build the projections, the ``wo`` selector and the optional rotary module.

        Args:
            config: Model hyper-parameters (``dim``, ``n_head``, ``n_local_heads``,
                ``head_dim``, ``block_size``, ``rope_base``, ``model_name`` are read).
            linear_class: Accepted for signature compatibility; not used here.
            linear_kwargs: Extra keyword arguments forwarded to every ``APLinearSel``.
            wo_class: Selector class instantiated for the output projection.
            wo_dict: Selector configuration; must contain ``"b_lh"``. When
                ``wo_class`` is ``selector_gemv`` a ``"jl"`` entry is written
                into this dict, i.e. the caller's object is mutated.
        """
        super().__init__()
        assert config.dim % config.n_head == 0

        extra_kwargs = linear_kwargs or {}
        self.wq = APLinearSel(config.dim, config.n_head * config.head_dim, bias=False, **extra_kwargs)
        self.wk = APLinearSel(config.dim, config.n_local_heads * config.head_dim, bias=False, **extra_kwargs)
        self.wv = APLinearSel(config.dim, config.n_local_heads * config.head_dim, bias=False, **extra_kwargs)
        self.wo = APLinearSel(config.dim, config.dim, bias=False, **extra_kwargs)

        self.wo_sne = create_streamNevent_full()
        if wo_class == selector_gemv:
            # Uninitialised (64, dim) fp16 buffer handed to the selector as "jl".
            self.wo_jl = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            wo_dict["jl"] = self.wo_jl
        self.wo_selector = wo_class(wo_dict, self.wo_sne, wo_dict["b_lh"][0], wo_dict["b_lh"][1])

        # Left as None here; forward() only uses it when something else assigns it.
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self._register_load_state_dict_pre_hook(self.load_hook)
        self.config = config

        # Only used by the _attn_flash* helpers; the SDPA call in forward()
        # relies on its default scale.
        self.scaling = 1 / math.sqrt(config.head_dim)

        model_name = config.model_name.lower()
        if "phi" in model_name:
            self.rotary_emb = Phi3RotaryEmbedding(
                self.head_dim,
                max_position_embeddings=config.block_size,
                base=config.rope_base,
            )
        elif "qwen" in model_name:
            self.rotary_emb = QwenRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=config.block_size,
                base=config.rope_base,
            )
        else:
            # No rotary module: forward() falls back to the precomputed freqs_cis.
            self.rotary_emb = None
        self.sdpa_scaling = None

    def load_hook(self, state_dict, prefix, *args):
        """State-dict pre-hook that concatenates wq/wk/wv weights into ``wqkv.weight``.

        NOTE(review): this class defines separate ``wq``/``wk``/``wv`` modules and
        no ``wqkv``; if a checkpoint really contains ``<prefix>wq.weight`` this
        hook removes the keys those modules would load from. Looks copied from
        the fused variant -- confirm whether it should fire here.
        """
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def _attn_flash_with_cache(self, bsz, seqlen, q, k, v, cache_seqlens):
        """Call ``flash_attn_with_cache_forward`` and reshape to ``(bsz, seqlen, n_head * head_dim)``.

        Not called from ``forward``; the wrapped function is not defined in
        this part of the file.
        """
        attn_output = flash_attn_with_cache_forward(
            q,
            k,
            v,
            cache_seqlens = cache_seqlens,
            softmax_scale = self.scaling,
        )
        attn_output = attn_output.reshape((bsz, seqlen, self.n_head * self.head_dim))
        return attn_output

    def _attn_flash(self, bsz, seqlen, q, k, v):
        """Call ``flash_attn_forward`` and reshape to ``(bsz, seqlen, n_head * head_dim)``.

        Not called from ``forward``; the wrapped function is not defined in
        this part of the file.
        """
        attn_output = flash_attn_forward(
            q,
            k,
            v,
            softmax_scale = self.scaling,
        )
        attn_output = attn_output.reshape((bsz, seqlen, self.n_head * self.head_dim))
        return attn_output

    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, 
                bsel1: Tensor, bsel2: Tensor, bsel3: Tensor, sne: int,
                input_pos: Optional[Tensor] = None,
                ) -> Tensor:
        """Compute attention over ``x`` and apply the output projection.

        Args:
            x: Input of shape ``(bsz, seqlen, dim)`` (unpacked as three dims).
            freqs_cis: Precomputed rotary table, used only when no rotary
                module is configured.
            mask: Passed to ``scaled_dot_product_attention`` as ``attn_mask``.
            bsel1, bsel2, bsel3: Selection arguments for ``wq``, ``wk``, ``wv``.
            sne: Stream/event handle passed to the q/k/v projections.
            input_pos: Positions; required when a rotary module or a KV cache
                is in use.

        Returns:
            Output of ``wo`` applied to the attention result reshaped to
            ``(bsz, seqlen, -1)``.
        """
        bsz, seqlen, _ = x.shape

        q = self.wq(x, bsel1, sne)
        k = self.wk(x, bsel2, sne)
        v = self.wv(x, bsel3, sne)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        if self.rotary_emb is None:
            q = apply_rotary_emb(q, freqs_cis)
            k = apply_rotary_emb(k, freqs_cis)
        else:
            cos, sin = self.rotary_emb(v, input_pos.unsqueeze(0))
            # q/k are still (bsz, seqlen, heads, head_dim) here while cos/sin
            # are (1, seqlen, head_dim), so the broadcast axis must be inserted
            # at the heads position (dim 2). The helper's default of 1 targets
            # the (bsz, heads, seqlen, head_dim) layout and only happens to
            # work when seqlen == 1.
            q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)

        # Move heads in front of the sequence axis: (bsz, heads, seqlen, head_dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Repeat each k/v head so their head count matches q's.
        n_rep = self.n_head // self.n_local_heads
        k = k.repeat_interleave(n_rep, dim=1)
        v = v.repeat_interleave(n_rep, dim=1)

        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        wo_bsel = self.wo_selector.compare(y)
        y = self.wo(y, wo_bsel, self.wo_sne)

        return y

class Attention_fused(nn.Module):
    """Self-attention with a single fused q/k/v ``APLinearSel`` projection.

    The fused projection receives one selection argument and a stream/event
    handle from the caller; the selection for ``wo`` is produced here by
    ``self.wo_selector.compare`` on the attention output.
    """

    def __init__(self, config: ModelArgs, linear_class=nn.Linear, linear_kwargs=None,
                 wo_class=selector_linreg, wo_dict={}) -> None:
        """Build the projections, the ``wo`` selector and the optional rotary module.

        Args:
            config: Model hyper-parameters (``dim``, ``n_head``, ``n_local_heads``,
                ``head_dim``, ``block_size``, ``rope_base``, ``model_name`` are read).
            linear_class: Accepted for signature compatibility; not used here.
            linear_kwargs: Extra keyword arguments forwarded to every ``APLinearSel``.
            wo_class: Selector class instantiated for the output projection.
            wo_dict: Selector configuration; must contain ``"b_lh"``. When
                ``wo_class`` is ``selector_gemv`` a ``"jl"`` entry is written
                into this dict, i.e. the caller's object is mutated.
        """
        super().__init__()
        assert config.dim % config.n_head == 0

        extra_kwargs = linear_kwargs or {}
        # q uses n_head heads, k and v use n_local_heads heads each.
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        self.wqkv = APLinearSel(config.dim, total_head_dim, bias=False, **extra_kwargs)
        self.wo = APLinearSel(config.dim, config.dim, bias=False, **extra_kwargs)

        self.wo_sne = create_streamNevent_full()
        if wo_class == selector_gemv:
            # Uninitialised (64, dim) fp16 buffer handed to the selector as "jl".
            self.wo_jl = torch.empty((64, config.dim), dtype=torch.float16, device='cuda')
            wo_dict["jl"] = self.wo_jl
        self.wo_selector = wo_class(wo_dict, self.wo_sne, wo_dict["b_lh"][0], wo_dict["b_lh"][1])

        # Left as None here; forward() only uses it when something else assigns it.
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self._register_load_state_dict_pre_hook(self.load_hook)
        self.config = config

        # Only used by the _attn_flash* helpers; the SDPA call in forward()
        # relies on its default scale.
        self.scaling = 1 / math.sqrt(config.head_dim)

        if "phi" in config.model_name.lower():
            self.rotary_emb = Phi3RotaryEmbedding(
                self.head_dim,
                max_position_embeddings=config.block_size,
                base=config.rope_base,
            )
        else:
            # No rotary module: forward() falls back to the precomputed freqs_cis.
            self.rotary_emb = None
        self.sdpa_scaling = None

    def load_hook(self, state_dict, prefix, *args):
        """State-dict pre-hook: fuse separate wq/wk/wv weights into ``wqkv.weight``.

        Runs only when ``<prefix>wq.weight`` is present; the three tensors are
        removed from ``state_dict`` and concatenated along dim 0 in q, k, v order.
        """
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def _attn_flash_with_cache(self, bsz, seqlen, q, k, v, cache_seqlens):
        """Call ``flash_attn_with_cache_forward`` and reshape to ``(bsz, seqlen, n_head * head_dim)``.

        Not called from ``forward``; the wrapped function is not defined in
        this part of the file.
        """
        attn_output = flash_attn_with_cache_forward(
            q,
            k,
            v,
            cache_seqlens = cache_seqlens,
            softmax_scale = self.scaling,
        )
        attn_output = attn_output.reshape((bsz, seqlen, self.n_head * self.head_dim))
        return attn_output

    def _attn_flash(self, bsz, seqlen, q, k, v):
        """Call ``flash_attn_forward`` and reshape to ``(bsz, seqlen, n_head * head_dim)``.

        Not called from ``forward``; the wrapped function is not defined in
        this part of the file.
        """
        attn_output = flash_attn_forward(
            q,
            k,
            v,
            softmax_scale = self.scaling,
        )
        attn_output = attn_output.reshape((bsz, seqlen, self.n_head * self.head_dim))
        return attn_output

    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, 
                bsel: Tensor, sne: int,
                input_pos: Optional[Tensor] = None,
                ) -> Tensor:
        """Compute attention over ``x`` and apply the output projection.

        Args:
            x: Input of shape ``(bsz, seqlen, dim)`` (unpacked as three dims).
            freqs_cis: Precomputed rotary table, used only when no rotary
                module is configured.
            mask: Passed to ``scaled_dot_product_attention`` as ``attn_mask``.
            bsel: Selection argument for the fused ``wqkv`` projection.
            sne: Stream/event handle passed to the fused projection.
            input_pos: Positions; required when a rotary module or a KV cache
                is in use.

        Returns:
            Output of ``wo`` applied to the attention result reshaped to
            ``(bsz, seqlen, -1)``.
        """
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x, bsel, sne).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        if self.rotary_emb is None:
            q = apply_rotary_emb(q, freqs_cis)
            k = apply_rotary_emb(k, freqs_cis)
        else:
            cos, sin = self.rotary_emb(v, input_pos.unsqueeze(0))
            # q/k are still (bsz, seqlen, heads, head_dim) here while cos/sin
            # are (1, seqlen, head_dim), so the broadcast axis must be inserted
            # at the heads position (dim 2). The helper's default of 1 targets
            # the (bsz, heads, seqlen, head_dim) layout and only happens to
            # work when seqlen == 1.
            q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)

        # Move heads in front of the sequence axis: (bsz, heads, seqlen, head_dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Repeat each k/v head so their head count matches q's.
        n_rep = self.n_head // self.n_local_heads
        k = k.repeat_interleave(n_rep, dim=1)
        v = v.repeat_interleave(n_rep, dim=1)

        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        wo_bsel = self.wo_selector.compare(y)
        y = self.wo(y, wo_bsel, self.wo_sne)

        return y

class FeedForward(nn.Module):
    """Gated feed-forward layer with separate ``w1`` / ``w3`` input projections.

    Computes ``w2(silu(w1(x)) * w3(x))``. Each ``APLinearSel`` call also
    receives a selection argument and a stream/event handle; the selection for
    ``w2`` comes from ``self.w2_selector.compare`` on the gated activation.
    """

    def __init__(self, config: ModelArgs, linear_class=nn.Linear, linear_kwargs=None,
                 w2_class=selector_linreg, w2_dict={}) -> None:
        """Create the three projections and the ``w2`` selector.

        Args:
            config: Model hyper-parameters (``dim`` and ``intermediate_size`` are read).
            linear_class: Accepted for signature compatibility; not used here.
            linear_kwargs: Extra keyword arguments forwarded to every ``APLinearSel``.
            w2_class: Selector class instantiated for the down projection.
            w2_dict: Selector configuration; must contain ``"b_lh"``. When
                ``w2_class`` is ``selector_gemv`` a ``"jl"`` entry is written
                into this dict, i.e. the caller's object is mutated.
        """
        super().__init__()
        self.config = config

        extra_kwargs = linear_kwargs or {}
        model_dim = config.dim
        hidden_dim = config.intermediate_size
        self.w1 = APLinearSel(model_dim, hidden_dim, bias=False, **extra_kwargs)
        self.w3 = APLinearSel(model_dim, hidden_dim, bias=False, **extra_kwargs)
        self.w2 = APLinearSel(hidden_dim, model_dim, bias=False, **extra_kwargs)

        self.w2_sne = create_streamNevent_full()
        if w2_class == selector_gemv:
            # Uninitialised (64, intermediate_size) fp16 buffer handed to the selector as "jl".
            self.w2_jl = torch.empty((64, hidden_dim), dtype=torch.float16, device='cuda')
            w2_dict["jl"] = self.w2_jl
        b_lh = w2_dict["b_lh"]
        self.w2_selector = w2_class(w2_dict, self.w2_sne, b_lh[0], b_lh[1])

        self.act_fn = F.silu

    def forward(self, x: Tensor, bsel1: Tensor, bsel2: Tensor, sne: int) -> Tensor:
        """Apply the gated feed-forward transform.

        Args:
            x: Input activations.
            bsel1: Selection argument for ``w1``.
            bsel2: Selection argument for ``w3``.
            sne: Stream/event handle passed to ``w1`` and ``w3``.

        Returns:
            Output of ``w2`` on the gated activation.
        """
        gate = self.w1(x, bsel1, sne)
        up = self.w3(x, bsel2, sne)
        gated = self.act_fn(gate) * up

        down_bsel = self.w2_selector.compare(gated)
        return self.w2(gated, down_bsel, self.w2_sne)

class FeedForward_fused(nn.Module):
    """Gated feed-forward layer with a single fused ``w1w3`` input projection.

    The fused output is split in two equal halves of ``intermediate_size``;
    the result is ``w2(silu(first_half) * second_half)``. The selection for
    ``w2`` comes from ``self.w2_selector.compare`` on the gated activation.
    """

    def __init__(self, config: ModelArgs, linear_class=nn.Linear, linear_kwargs=None,
                 w2_class=selector_linreg, w2_dict={}) -> None:
        """Create the fused projection, the down projection and the ``w2`` selector.

        Args:
            config: Model hyper-parameters (``dim`` and ``intermediate_size`` are read).
            linear_class: Accepted for signature compatibility; not used here.
            linear_kwargs: Extra keyword arguments forwarded to every ``APLinearSel``.
            w2_class: Selector class instantiated for the down projection.
            w2_dict: Selector configuration; must contain ``"b_lh"``. When
                ``w2_class`` is ``selector_gemv`` a ``"jl"`` entry is written
                into this dict, i.e. the caller's object is mutated.
        """
        super().__init__()
        self.config = config

        extra_kwargs = linear_kwargs or {}
        hidden_dim = config.intermediate_size
        self.w1w3 = APLinearSel(config.dim, hidden_dim * 2, bias=False, **extra_kwargs)
        self.w2 = APLinearSel(hidden_dim, config.dim, bias=False, **extra_kwargs)

        self.w2_sne = create_streamNevent_full()
        if w2_class == selector_gemv:
            # Uninitialised (64, intermediate_size) fp16 buffer handed to the selector as "jl".
            self.w2_jl = torch.empty((64, hidden_dim), dtype=torch.float16, device='cuda')
            w2_dict["jl"] = self.w2_jl
        b_lh = w2_dict["b_lh"]
        self.w2_selector = w2_class(w2_dict, self.w2_sne, b_lh[0], b_lh[1])

        self.act_fn = F.silu

    def forward(self, x: Tensor, bsel: Tensor, sne: int) -> Tensor:
        """Apply the gated feed-forward transform.

        Args:
            x: Input activations.
            bsel: Selection argument for the fused ``w1w3`` projection.
            sne: Stream/event handle passed to the fused projection.

        Returns:
            Output of ``w2`` on the gated activation.
        """
        half = self.config.intermediate_size
        fused = self.w1w3(x, bsel, sne)
        gate, up = fused.split([half, half], dim=-1)
        gated = self.act_fn(gate) * up

        down_bsel = self.w2_selector.compare(gated)
        return self.w2(gated, down_bsel, self.w2_sne)
    


def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension of width ``w`` the split point is ``w // 2``; the
    result is ``concat(-x[..., w//2:], x[..., :w//2])``.
    """
    width = x.shape[-1]
    split_at = width // 2
    front, back = x.split([split_at, width - split_at], dim=-1)
    return torch.cat((-back, front), dim=-1)

class Phi3RotaryEmbedding(nn.Module):
    """Rotary embedding that returns cos/sin tables for the given position ids.

    ``inv_freq`` is registered as a non-persistent buffer that starts out as
    ``None`` and is rebuilt on the device of ``x`` on every forward call.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """Store the rotary dimension and base; ``device`` is accepted but unused."""
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Placeholder only; forward() overwrites it on each call.
        self.register_buffer("inv_freq", None, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)``, each shaped ``(position_ids.shape[0], num_positions, dim)``.

        Only the device and dtype of ``x`` are read; ``seq_len`` is ignored.
        """
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
        self.inv_freq = 1.0 / (self.base ** exponents)

        batch = position_ids.shape[0]
        inv_freq_col = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        positions_row = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            angles = (inv_freq_col.float() @ positions_row.float()).transpose(1, 2)
            # The same angles are used for both halves of the last dimension.
            table = torch.cat((angles, angles), dim=-1)
            cos, sin = table.cos(), table.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

class Gemma2RotaryEmbedding(nn.Module):
    """Rotary embedding that returns cos/sin tables for the given position ids.

    ``inv_freq`` is computed once in ``__init__`` and stored as a
    non-persistent buffer.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """Precompute the inverse frequencies; ``device`` is accepted but unused."""
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)``, each shaped ``(position_ids.shape[0], num_positions, dim)``.

        Only the device and dtype of ``x`` are read; ``seq_len`` is ignored.
        """
        # Tensor.to() is out-of-place, so the returned tensor must be used;
        # calling it for its side effect leaves the buffer where it was.
        inv_freq = self.inv_freq.to(x.device)
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # The same angles are used for both halves of the last dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

class QwenRotaryEmbedding(nn.Module):
    """Rotary embedding that returns cos/sin tables for the given position ids.

    ``inv_freq`` is computed once in ``__init__`` and stored as a
    non-persistent buffer.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """Precompute the inverse frequencies; ``device`` is accepted but unused."""
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)``, each shaped ``(position_ids.shape[0], num_positions, dim)``.

        Only the device and dtype of ``x`` are read; ``seq_len`` is ignored.
        """
        # Tensor.to() is out-of-place, so the returned tensor must be used;
        # calling it for its side effect leaves the buffer where it was.
        inv_freq = self.inv_freq.to(x.device)
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # The same angles are used for both halves of the last dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

class Phi3RMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-channel scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Phi3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise in float32, cast back to the input dtype, then apply ``weight``."""
        source_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (upcast * inv_rms).to(source_dtype)
        return self.weight * normalized

class Phi3RMSNorm_dual(nn.Module):
    """RMS normalisation that returns the normalised input under two different scales.

    The normalised tensor is multiplied once by ``weight`` and once by
    ``weight_next``; both results are returned as a pair.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        Phi3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.weight_next = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Return ``(weight * normed, weight_next * normed)``; normalisation runs in float32."""
        source_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (upcast * inv_rms).to(source_dtype)
        return (self.weight * normalized, self.weight_next * normalized)
    
class Gemma2RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

class RMSNorm_dual(nn.Module):
    """RMS normalisation that returns the normalised input under two different scales.

    The normalised tensor is multiplied once by ``weight`` and once by
    ``weight_next``; both results are returned as a pair.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.weight_next = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by its root-mean-square over the last dimension."""
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    # The annotation is a string so it is never evaluated at definition time,
    # whatever Python version the file runs on.
    def forward(self, x: Tensor) -> "tuple[Tensor, Tensor]":
        """Return ``(normed * weight, normed * weight_next)``.

        Normalisation runs in float32 and is cast back to the dtype of ``x``
        before either scale is applied.
        """
        output = self._norm(x.float()).type_as(x)
        return (output * self.weight, output * self.weight_next)

def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Optional[dict] = None):
    """Rescale rotary frequencies according to a wavelength-banded scaling config.

    Each frequency is classified by its wavelength ``2 * pi / freq``:

    * shorter than ``original_max_position_embeddings / high_freq_factor``:
      kept unchanged;
    * longer than ``original_max_position_embeddings / low_freq_factor``:
      divided by ``factor``;
    * otherwise: linear blend between the scaled and unscaled frequency.

    Args:
        freqs: Tensor of frequencies.
        rope_scaling: Dict with keys ``"factor"``, ``"low_freq_factor"``,
            ``"high_freq_factor"`` and ``"original_max_position_embeddings"``.
            ``None`` means "no scaling" and returns ``freqs`` unchanged.

    Returns:
        Tensor with the same shape, dtype and device as ``freqs``.

    Raises:
        AssertionError: if a frequency falls in the blend band while both
            band limits are equal.
    """
    if rope_scaling is None:
        # The declared default must be usable: no config means no scaling.
        return freqs

    factor = rope_scaling["factor"]
    low_freq_factor = rope_scaling["low_freq_factor"]
    high_freq_factor = rope_scaling["high_freq_factor"]
    old_context_len = rope_scaling["original_max_position_embeddings"]

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / freqs
    keep_mask = wavelen < high_freq_wavelen
    # Evaluated in the same priority order as an if / elif / else chain.
    shrink_mask = ~keep_mask & (wavelen > low_freq_wavelen)
    blend_mask = ~(keep_mask | shrink_mask)
    if bool(blend_mask.any()):
        assert low_freq_wavelen != high_freq_wavelen

    # Computed for every element; torch.where below discards the values that
    # belong to the other two bands.
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    blended = (1 - smooth) * freqs / factor + smooth * freqs

    scaled = torch.where(keep_mask, freqs, torch.where(shrink_mask, freqs / factor, blended))
    return scaled.to(dtype=freqs.dtype)


def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000,
    dtype: torch.dtype = torch.bfloat16,
    rope_scaling: Optional[dict] = None,
) -> Tensor:
    """Build the rotary table of (real, imaginary) pairs for every position.

    Args:
        seq_len: Number of positions to tabulate.
        n_elem: Rotary dimension; ``n_elem // 2`` frequencies are generated.
        base: Base of the geometric frequency progression.
        dtype: Dtype of the returned table.
        rope_scaling: Optional scaling config forwarded to ``apply_rope_scaling``.

    Returns:
        Tensor of shape ``(seq_len, n_elem // 2, 2)`` whose last axis holds
        the real and imaginary parts of the unit complex number at each angle.
    """
    half = n_elem // 2
    exponents = torch.arange(0, n_elem, 2)[:half].float() / n_elem
    inv_freq = 1.0 / (base ** exponents)
    if rope_scaling is not None:
        inv_freq = apply_rope_scaling(inv_freq, rope_scaling)

    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    unit = torch.polar(torch.ones_like(angles), angles)
    table = torch.stack([unit.real, unit.imag], dim=-1)
    return table.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive element pairs of ``x`` by the angles stored in ``freqs_cis``.

    The last dimension of ``x`` is viewed as pairs ``(a, b)``; each pair is
    multiplied, as a complex number ``a + ib``, by the ``(real, imag)`` entry
    of ``freqs_cis`` for the same position (dim 1 of ``x``) and pair index.
    The arithmetic runs in float32 and the result is cast back to the dtype
    of ``x``.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    # Broadcast over the batch axis (dim 0) and the head axis (dim 2).
    rot = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)

    x_re, x_im = pairs[..., 0], pairs[..., 1]
    rot_re, rot_im = rot[..., 0], rot[..., 1]

    out_re = x_re * rot_re - x_im * rot_im
    out_im = x_im * rot_re + x_re * rot_im
    rotated = torch.stack([out_re, out_im], -1)

    # Merge the (pair, 2) axes back into a single last dimension.
    return rotated.flatten(3).type_as(x)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to ``q`` and ``k`` using cos/sin tables.

    ``cos`` and ``sin`` get one extra axis inserted at ``unsqueeze_dim`` so
    they broadcast against ``q`` and ``k``; each tensor ``t`` is then mapped
    to ``t * cos + rotate_half(t) * sin``. ``position_ids`` is accepted but
    not used.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)

