from typing import Callable, Optional, Union
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import SiLUActivation, GELUTanh, NewGELUActivation
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
from transformers.processing_utils import Unpack
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from transformers.utils import TransformersKwargs
import logging
import math
import time
from transformers.modeling_utils import (
    ALL_ATTENTION_FUNCTIONS
)
from transformers import (
    Qwen3ForCausalLM, 
    LlamaForCausalLM, 
    GPT2LMHeadModel,
    Gemma3ForCausalLM
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions
)
from transformers.masking_utils import (
    create_causal_mask, 
    create_sliding_window_causal_mask
)
from transformers.models.gemma3.modeling_gemma3 import (
    _bidirectional_window_overlay,
)
import transformers


# Import SGX obfuscation module
from sgx_api import sgx_groupcover


# Custom formatter with colors for different log levels
class ColoredFormatter(logging.Formatter):
    """Formatter that prefixes each record with an ANSI-colored ``[LEVEL]`` tag.

    Rendered layout: ``<color>[LEVELNAME]<reset> logger.name: message``.
    Levels without an entry in ``COLORS`` use the reset code (i.e. no color).
    """

    # ANSI color codes
    COLORS = {
        logging.DEBUG: '\033[94m',    # Blue
        logging.INFO: '\033[92m',     # Green
        logging.WARNING: '\033[93m',  # Yellow
        logging.ERROR: '\033[91m',    # Red
        logging.CRITICAL: '\033[95m', # Purple
        'RESET': '\033[0m'            # Reset color
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One pre-built formatter per level number. Constructing a Formatter
        # parses/validates its format string, so doing it per record is waste.
        self._level_formatters: dict[int, logging.Formatter] = {}

    def format(self, record):
        """Format ``record`` with the color associated with its level number."""
        formatter = self._level_formatters.get(record.levelno)
        if formatter is None:
            color = self.COLORS.get(record.levelno, self.COLORS['RESET'])
            formatter = logging.Formatter(f'{color}[%(levelname)s]{self.COLORS["RESET"]} %(name)s: %(message)s')
            self._level_formatters[record.levelno] = formatter
        return formatter.format(record)

# Module logger: INFO level, colored console output.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = ColoredFormatter()
console_handler.setFormatter(formatter)
# getLogger returns the same logger object on module reload; only attach the
# console handler once so messages are not printed twice.
if not logger.handlers:
    logger.addHandler(console_handler)


def _restore(y, blocks, permutation, bias=None):    
    if bias is not None:
        y = y - bias
    
    inv_perm = torch.argsort(permutation)
    y = y[:,:,inv_perm]
    
    current_start = 0
    for block in blocks:
        block_size = block.shape[0]
        end = current_start + block_size
        
        y[:, :, current_start:end] = y[:, :, current_start:end] @ block
        
        current_start = end
    
    if bias is not None:
        y = y + bias
    return y


class CustomQwen3ForCausalLM(Qwen3ForCausalLM):
    def __init__(self, config, obf_param, simulate=True):
        super().__init__(config)
        self.simulate = simulate
        if self.simulate:
            self.q_proj_param = []
            for layer_param in obf_param["q_proj"]:
                self.q_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.k_proj_param = []
            for layer_param in obf_param["k_proj"]:
                self.k_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.v_proj_param = []
            for layer_param in obf_param["v_proj"]:
                self.v_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.o_proj_param = []
            for layer_param in obf_param["o_proj"]:
                self.o_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.gate_proj_param = []
            for layer_param in obf_param["gate_proj"]:
                self.gate_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.up_proj_param = []
            for layer_param in obf_param["up_proj"]:
                self.up_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.down_proj_param = []
            for layer_param in obf_param["down_proj"]:
                self.down_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
        else:
            self.sgx = sgx_groupcover.get_sgx_instance()
            
            if not self.sgx.prepare_obf_params(obf_param):
                raise RuntimeError("Failed to prepare obfuscation parameters in SGX")
            
            
    def prepare_norm_params(self):
        if not self.sgx.prepare_norm_params(self.model):
            raise RuntimeError("Failed to prepare normalization parameters in SGX")
    
    def _custom_layernorm(self, layer_idx: int, norm_type: str, hidden_states):
        if self.simulate:
            match norm_type:
                case "input_layernorm":
                    hidden_states = self.model.layers[layer_idx].input_layernorm(hidden_states)
                case "post_attention_layernorm":
                    hidden_states = self.model.layers[layer_idx].post_attention_layernorm(hidden_states)
                case "q_norm":
                    hidden_states = self.model.layers[layer_idx].self_attn.q_norm(hidden_states)
                case "k_norm":
                    hidden_states = self.model.layers[layer_idx].self_attn.k_norm(hidden_states)
                case "norm":
                    hidden_states = self.model.norm(hidden_states)
                case _:
                    raise ValueError(f"No such type norm: {norm_type}")
        else:
            hidden_states = self.sgx.norm(hidden_states, layer_idx, norm_type)
        return hidden_states
    
    def _custom_attn(
        self,
        layer_idx,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        self_attn = self.model.layers[layer_idx].self_attn
        
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self_attn.head_dim)

        query_states = self_attn.q_proj(hidden_states)
        key_states = self_attn.k_proj(hidden_states)
        value_states = self_attn.v_proj(hidden_states)
        
        if self.simulate:
            query_states = _restore(query_states, **self.q_proj_param[layer_idx])
            key_states = _restore(key_states, **self.k_proj_param[layer_idx])
            value_states = _restore(value_states, **self.v_proj_param[layer_idx])
            
            query_states = self._custom_layernorm(layer_idx, "q_norm", query_states.view(hidden_shape))
            key_states = self._custom_layernorm(layer_idx, "k_norm", key_states.view(hidden_shape))
        else:
            query_states = self.sgx.restore(
                query_states, 
                layer_idx,
                "q_proj", 
                self_attn.q_proj.bias if hasattr(self_attn.q_proj, "bias") else None
            )
            key_states = self.sgx.restore(
                key_states, 
                layer_idx,
                "k_proj", 
                self_attn.k_proj.bias if hasattr(self_attn.k_proj, "bias") else None
            )
            value_states = self.sgx.restore(
                value_states, 
                layer_idx,
                "v_proj", 
                self_attn.v_proj.bias if hasattr(self_attn.v_proj, "bias") else None
            )
            
            query_states = self._custom_layernorm(layer_idx, "q_norm", query_states.view(hidden_shape))
            key_states = self._custom_layernorm(layer_idx, "k_norm", key_states.view(hidden_shape))
        
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.view(hidden_shape).transpose(1, 2)
        

        cos, sin = position_embeddings
        query_states, key_states = transformers.models.qwen3.modeling_qwen3.apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self_attn.layer_idx, cache_kwargs)

        attention_interface: Callable = transformers.models.qwen3.modeling_qwen3.eager_attention_forward
        if self_attn.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self_attn.config._attn_implementation]

        attn_output, _ = attention_interface(
            self_attn,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self_attn.training else self_attn.attention_dropout,
            scaling=self_attn.scaling,
            sliding_window=self_attn.sliding_window,  # diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        o_proj_output = self_attn.o_proj(attn_output)
        
        if self.simulate:
            attn_output = _restore(o_proj_output, **self.o_proj_param[layer_idx])
        else:
            attn_output = self.sgx.restore(
                o_proj_output, 
                layer_idx, 
                "o_proj", 
                self_attn.o_proj.bias if hasattr(self_attn.o_proj, "bias") else None
            )
        
        return attn_output
    
    def _custom_mlp(self, layer_idx, x):
        mlp = self.model.layers[layer_idx].mlp
        up_output = mlp.up_proj(x)
        gate_output = mlp.gate_proj(x)
        if self.simulate:
            up_output = _restore(up_output, **self.up_proj_param[layer_idx])
            gate_output = _restore(gate_output, **self.gate_proj_param[layer_idx])
            
            if isinstance(mlp.act_fn, SiLUActivation):
                act_output = nn.functional.silu(gate_output)
            else:
                raise ValueError(f"Unsupported Activation: {type(mlp.act_fn)}")
        else:
            up_output = self.sgx.restore(
                up_output, 
                layer_idx, 
                "up_proj", 
                mlp.up_proj.bias if hasattr(mlp.up_proj, "bias") else None
            )
            gate_output = self.sgx.restore(
                gate_output, 
                layer_idx, 
                "gate_proj", 
                mlp.gate_proj.bias if hasattr(mlp.gate_proj, "bias") else None
            )
            
            if isinstance(mlp.act_fn, SiLUActivation):
                act_output = self.sgx.silu_activation(gate_output)
            else:
                raise ValueError(f"Unsupported Activation: {type(mlp.act_fn)}")
        
        output = act_output * up_output
        down_proj = mlp.down_proj(output)
        
        if self.simulate:
            down_proj = _restore(down_proj, **self.down_proj_param[layer_idx])
        else:
            down_proj = self.sgx.restore(
                down_proj, 
                layer_idx, 
                "down_proj", 
                mlp.down_proj.bias if hasattr(mlp.down_proj, "bias") else None
            )
            
        return down_proj
    
    def _custom_decoder_forward(
        self,
        decoder_layer_idx,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self._custom_layernorm(decoder_layer_idx, "input_layernorm", hidden_states)   
        
        # Self Attention
        hidden_states = self._custom_attn(
            decoder_layer_idx,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self._custom_layernorm(decoder_layer_idx, "post_attention_layernorm", hidden_states)
        hidden_states = self._custom_mlp(decoder_layer_idx, hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
        
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs]
    ) -> CausalLMOutputWithPast:
        if not self.simulate:
            self.sgx.reset_time()
        # Measure the time spent in the forward function
        start_time = time.perf_counter()
        
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.model.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.model.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.model.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.model.rotary_emb(hidden_states, position_ids)

        for i, decoder_layer in enumerate(self.model.layers[: self.model.config.num_hidden_layers]):
            # For the first decoder layer, we need to obfuscate the hidden_states
            hidden_states = self._custom_decoder_forward(
                i,
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self._custom_layernorm(-1, "norm", hidden_states)
        
        transformer_outputs = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )

        hidden_states = transformer_outputs.last_hidden_state
        
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        output = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
        torch.cuda.synchronize()
        print(f"----------------------------------------------------")
        print(f"Foward time: {(time.perf_counter() - start_time) * 1000:.6f} milliseconds")
        if not self.simulate:
            sgx_compute_time, all_time = self.sgx.get_exe_time()
            print(f"Sgx compute time: {sgx_compute_time:.6f} milliseconds")
            print(f"Sgx compute + data transfer time: {all_time:.6f} milliseconds")
        print(f"----------------------------------------------------")
        return output

class CustomLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config, obf_param, simulate=True):
        super().__init__(config) 
        self.simulate = simulate
        if self.simulate:
            self.q_proj_param = []
            for layer_param in obf_param["q_proj"]:
                self.q_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.k_proj_param = []
            for layer_param in obf_param["k_proj"]:
                self.k_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.v_proj_param = []
            for layer_param in obf_param["v_proj"]:
                self.v_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.o_proj_param = []
            for layer_param in obf_param["o_proj"]:
                self.o_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.gate_proj_param = []
            for layer_param in obf_param["gate_proj"]:
                self.gate_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.up_proj_param = []
            for layer_param in obf_param["up_proj"]:
                self.up_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.down_proj_param = []
            for layer_param in obf_param["down_proj"]:
                self.down_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
        else:
            self.sgx = sgx_groupcover.get_sgx_instance()
            
            if not self.sgx.prepare_obf_params(obf_param):
                raise RuntimeError("Failed to prepare obfuscation parameters in SGX")
            
    
    def _custom_layernorm(self, layer_idx: int, norm_type: str, hidden_states):
        if self.simulate:
            match norm_type:
                case "input_layernorm":
                    hidden_states = self.model.layers[layer_idx].input_layernorm(hidden_states)
                case "post_attention_layernorm":
                    hidden_states = self.model.layers[layer_idx].post_attention_layernorm(hidden_states)
                case "norm":
                    hidden_states = self.model.norm(hidden_states)
                case _:
                    raise ValueError(f"No such type norm: {norm_type}")
        else:
            hidden_states = self.sgx.norm(hidden_states, layer_idx, norm_type)
        return hidden_states

    def prepare_norm_params(self):
        if not self.sgx.prepare_norm_params(self.model):
            raise RuntimeError("Failed to prepare normalization parameters in SGX")
    
    def _custom_attn(
        self,
        layer_idx,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        self_attn = self.model.layers[layer_idx].self_attn
        
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self_attn.head_dim)

        query_states = self_attn.q_proj(hidden_states)
        key_states = self_attn.k_proj(hidden_states)
        value_states = self_attn.v_proj(hidden_states)
        
        if self.simulate:
            query_states = _restore(query_states, **self.q_proj_param[layer_idx])
            key_states = _restore(key_states, **self.k_proj_param[layer_idx])
            value_states = _restore(value_states, **self.v_proj_param[layer_idx])
        else:
            query_states = self.sgx.restore(
                query_states, 
                layer_idx,
                "q_proj", 
                self_attn.q_proj.bias if hasattr(self_attn.q_proj, "bias") else None
            )
            key_states = self.sgx.restore(
                key_states, 
                layer_idx,
                "k_proj", 
                self_attn.k_proj.bias if hasattr(self_attn.k_proj, "bias") else None
            )
            value_states = self.sgx.restore(
                value_states, 
                layer_idx,
                "v_proj", 
                self_attn.v_proj.bias if hasattr(self_attn.v_proj, "bias") else None
            )
        
        query_states = query_states.view(hidden_shape).transpose(1, 2)
        key_states = key_states.view(hidden_shape).transpose(1, 2)
        value_states = value_states.view(hidden_shape).transpose(1, 2)
        

        cos, sin = position_embeddings
        query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self_attn.layer_idx, cache_kwargs)

        attention_interface: Callable = transformers.models.llama.modeling_llama.eager_attention_forward
        if self_attn.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self_attn.config._attn_implementation]

        attn_output, _ = attention_interface(
            self_attn,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self_attn.training else self_attn.attention_dropout,
            scaling=self_attn.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        o_proj_output = self_attn.o_proj(attn_output)
        
        if self.simulate:
            attn_output = _restore(o_proj_output, **self.o_proj_param[layer_idx])
        else:
            attn_output = self.sgx.restore(
                o_proj_output, 
                layer_idx, 
                "o_proj", 
                self_attn.o_proj.bias if hasattr(self_attn.o_proj, "bias") else None
            )
        
        return attn_output
    
    def _custom_mlp(self, layer_idx, x):
        mlp = self.model.layers[layer_idx].mlp
        up_output = mlp.up_proj(x)
        gate_output = mlp.gate_proj(x)
        if self.simulate:
            up_output = _restore(up_output, **self.up_proj_param[layer_idx])
            gate_output = _restore(gate_output, **self.gate_proj_param[layer_idx])
            
            if isinstance(mlp.act_fn, SiLUActivation):
                act_output = nn.functional.silu(gate_output)
            else:
                raise ValueError(f"Unsupported Activation: {type(mlp.act_fn)}")
        else:
            up_output = self.sgx.restore(
                up_output, 
                layer_idx, 
                "up_proj", 
                mlp.up_proj.bias if hasattr(mlp.up_proj, "bias") else None
            )
            gate_output = self.sgx.restore(
                gate_output, 
                layer_idx, 
                "gate_proj", 
                mlp.gate_proj.bias if hasattr(mlp.gate_proj, "bias") else None
            )
            
            if isinstance(mlp.act_fn, SiLUActivation):
                act_output = self.sgx.silu_activation(gate_output)
            else:
                raise ValueError(f"Unsupported Activation: {type(mlp.act_fn)}")
        
        output = act_output * up_output
        down_proj = mlp.down_proj(output)
        
        if self.simulate:
            down_proj = _restore(down_proj, **self.down_proj_param[layer_idx])
        else:
            down_proj = self.sgx.restore(
                down_proj, 
                layer_idx, 
                "down_proj", 
                mlp.down_proj.bias if hasattr(mlp.down_proj, "bias") else None
            )
            
        return down_proj
    
    def _custom_decoder_forward(
        self,
        decoder_layer_idx,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self._custom_layernorm(decoder_layer_idx, "input_layernorm", hidden_states)
            
        # Self Attention
        hidden_states = self._custom_attn(
            decoder_layer_idx,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self._custom_layernorm(decoder_layer_idx, "post_attention_layernorm", hidden_states)
        hidden_states = self._custom_mlp(decoder_layer_idx, hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
        
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs]
    ) -> CausalLMOutputWithPast:
        if not self.simulate:
            self.sgx.reset_time()
        # Measure the time spent in the forward function
        start_time = time.perf_counter()
        
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.model.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.model.rotary_emb(hidden_states, position_ids)
        
        for i, decoder_layer in enumerate(self.model.layers[: self.model.config.num_hidden_layers]):
            # For the first decoder layer, we need to obfuscate the hidden_states
            hidden_states = self._custom_decoder_forward(
                i,
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self._custom_layernorm(-1, "norm", hidden_states)
        
        transformer_outputs = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )

        hidden_states = transformer_outputs.last_hidden_state
        
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        output = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
        torch.cuda.synchronize()
        print(f"----------------------------------------------------")
        print(f"Foward time: {(time.perf_counter() - start_time) * 1000:.6f} milliseconds")
        if not self.simulate:
            sgx_compute_time, all_time = self.sgx.get_exe_time()
            print(f"Sgx compute time: {sgx_compute_time:.6f} milliseconds")
            print(f"Sgx compute + data transfer time: {all_time:.6f} milliseconds")
        print(f"----------------------------------------------------")
        return output
    
class CustomGemma3ForCausalLM(Gemma3ForCausalLM):
    def __init__(self, config, obf_param, simulate=True):
        super().__init__(config)
        self.simulate = simulate
        if self.simulate:
            self.q_proj_param = []
            for layer_param in obf_param["q_proj"]:
                self.q_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.k_proj_param = []
            for layer_param in obf_param["k_proj"]:
                self.k_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.v_proj_param = []
            for layer_param in obf_param["v_proj"]:
                self.v_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.o_proj_param = []
            for layer_param in obf_param["o_proj"]:
                self.o_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.gate_proj_param = []
            for layer_param in obf_param["gate_proj"]:
                self.gate_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.up_proj_param = []
            for layer_param in obf_param["up_proj"]:
                self.up_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
            self.down_proj_param = []
            for layer_param in obf_param["down_proj"]:
                self.down_proj_param.append({"permutation": layer_param["permutation"], "blocks": [torch.linalg.inv(block) for block in layer_param["blocks"]]})
        else:
            self.sgx = sgx_groupcover.get_sgx_instance()
            
            if not self.sgx.prepare_obf_params(obf_param):
                raise RuntimeError("Failed to prepare obfuscation parameters in SGX")
            
    def prepare_norm_params(self):
        if not self.sgx.prepare_norm_params(self.model, True):
            raise RuntimeError("Failed to prepare normalization parameters in SGX")
            
    def _custom_layernorm(self, layer_idx: int, norm_type: str, hidden_states):
        if self.simulate:
            match norm_type:
                case "input_layernorm":
                    hidden_states = self.model.layers[layer_idx].input_layernorm(hidden_states)
                case "post_attention_layernorm":
                    hidden_states = self.model.layers[layer_idx].post_attention_layernorm(hidden_states)
                case "pre_feedforward_layernorm":
                    hidden_states = self.model.layers[layer_idx].pre_feedforward_layernorm(hidden_states)
                case "post_feedforward_layernorm":
                    hidden_states = self.model.layers[layer_idx].post_feedforward_layernorm(hidden_states)
                case "q_norm":
                    hidden_states = self.model.layers[layer_idx].self_attn.q_norm(hidden_states)
                case "k_norm":
                    hidden_states = self.model.layers[layer_idx].self_attn.k_norm(hidden_states)
                case "norm":
                    hidden_states = self.model.norm(hidden_states)
                case _:
                    raise ValueError(f"No such type norm: {norm_type}")
        else:
            hidden_states = self.sgx.norm(hidden_states, layer_idx, norm_type)
        return hidden_states
    
    def _custom_attn(
        self,
        layer_idx,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        self_attn = self.model.layers[layer_idx].self_attn
        
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self_attn.head_dim)

        query_states = self_attn.q_proj(hidden_states)
        key_states = self_attn.k_proj(hidden_states)
        value_states = self_attn.v_proj(hidden_states)
        
        if self.simulate:
            query_states = _restore(query_states, **self.q_proj_param[layer_idx])
            key_states = _restore(key_states, **self.k_proj_param[layer_idx])
            value_states = _restore(value_states, **self.v_proj_param[layer_idx])
            
            query_states = self._custom_layernorm(layer_idx, "q_norm", query_states.view(hidden_shape).transpose(1, 2))
            key_states = self._custom_layernorm(layer_idx, "k_norm", key_states.view(hidden_shape).transpose(1, 2))
        else:
            query_states = self.sgx.restore(
                query_states, 
                layer_idx,
                "q_proj", 
                self_attn.q_proj.bias if hasattr(self_attn.q_proj, "bias") else None
            )
            key_states = self.sgx.restore(
                key_states, 
                layer_idx,
                "k_proj", 
                self_attn.k_proj.bias if hasattr(self_attn.k_proj, "bias") else None
            )
            value_states = self.sgx.restore(
                value_states, 
                layer_idx,
                "v_proj", 
                self_attn.v_proj.bias if hasattr(self_attn.v_proj, "bias") else None
            )
            
            query_states = self._custom_layernorm(layer_idx, "q_norm", query_states.view(hidden_shape).transpose(1, 2))
            key_states = self._custom_layernorm(layer_idx, "k_norm", key_states.view(hidden_shape).transpose(1, 2))

        value_states = value_states.view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = transformers.models.gemma3.modeling_gemma3.apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self_attn.layer_idx, cache_kwargs)

        attention_interface: Callable = transformers.models.gemma3.modeling_gemma3.eager_attention_forward
        if self_attn.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self_attn.config._attn_implementation]

        attn_output, _ = attention_interface(
            self_attn,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self_attn.training else self_attn.attention_dropout,
            scaling=self_attn.scaling,
            sliding_window=self_attn.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        o_proj_output = self_attn.o_proj(attn_output)
        
        if self.simulate:
            attn_output = _restore(o_proj_output, **self.o_proj_param[layer_idx])
        else:
            attn_output = self.sgx.restore(
                o_proj_output, 
                layer_idx, 
                "o_proj", 
                self_attn.o_proj.bias if hasattr(self_attn.o_proj, "bias") else None
            )
        
        return attn_output
    
    def _custom_mlp(self, layer_idx, x):
        mlp = self.model.layers[layer_idx].mlp
        up_output = mlp.up_proj(x)
        gate_output = mlp.gate_proj(x)
        if self.simulate:
            up_output = _restore(up_output, **self.up_proj_param[layer_idx])
            gate_output = _restore(gate_output, **self.gate_proj_param[layer_idx])
            
            if isinstance(mlp.act_fn, GELUTanh):
                act_output = nn.functional.gelu(gate_output)
            else:
                raise ValueError(f"Unsupported Activation: {type(mlp.act_fn)}")
        else:
            up_output = self.sgx.restore(
                up_output, 
                layer_idx, 
                "up_proj", 
                mlp.up_proj.bias if hasattr(mlp.up_proj, "bias") else None
            )
            gate_output = self.sgx.restore(
                gate_output, 
                layer_idx, 
                "gate_proj", 
                mlp.gate_proj.bias if hasattr(mlp.gate_proj, "bias") else None
            )
            
            if isinstance(mlp.act_fn, GELUTanh):
                act_output = self.sgx.gelu_tanh_activation(gate_output)
            else:
                raise ValueError(f"Unsupported Activation: {type(mlp.act_fn)}")
        
        output = act_output * up_output
        down_proj = mlp.down_proj(output)
        
        if self.simulate:
            down_proj = _restore(down_proj, **self.down_proj_param[layer_idx])
        else:
            down_proj = self.sgx.restore(
                down_proj, 
                layer_idx, 
                "down_proj", 
                mlp.down_proj.bias if hasattr(mlp.down_proj, "bias") else None
            )
            
        return down_proj
    
    def _custom_decoder_forward(
        self,
        decoder_layer_idx,
        hidden_states: torch.Tensor,
        position_embeddings_global: torch.Tensor,
        position_embeddings_local: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self._custom_layernorm(decoder_layer_idx, "input_layernorm", hidden_states)
        
        # apply global RoPE to non-sliding layer only
        if self.model.layers[decoder_layer_idx].self_attn.is_sliding:
            position_embeddings = position_embeddings_local
        else:
            position_embeddings = position_embeddings_global
            
        # Self Attention
        hidden_states = self._custom_attn(
            decoder_layer_idx,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        
        hidden_states = self._custom_layernorm(decoder_layer_idx, "post_attention_layernorm", hidden_states)
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self._custom_layernorm(decoder_layer_idx, "pre_feedforward_layernorm", hidden_states)
        hidden_states = self._custom_mlp(decoder_layer_idx, hidden_states)
        hidden_states = self._custom_layernorm(decoder_layer_idx, "post_feedforward_layernorm", hidden_states)
        hidden_states = residual + hidden_states
        
        return hidden_states
        
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs]
    ) -> CausalLMOutputWithPast:
        if not self.simulate:
            self.sgx.reset_time()
        # Measure the time spent in the forward function
        start_time = time.perf_counter()
        
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.model.config.use_cache
        
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.model.gradient_checkpointing and self.model.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        if use_cache and past_key_values is None and not self.model.training:
            past_key_values = DynamicCache(config=self.model.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.model.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            sliding_mask_kwargs = mask_kwargs.copy()

            if self.model.config.use_bidirectional_attention:
                mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
                sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.model.config.sliding_window)

            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
            }

        # embed positions
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings_global = self.model.rotary_emb(hidden_states, position_ids)
        position_embeddings_local = self.model.rotary_emb_local(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        
        for i, decoder_layer in enumerate(self.model.layers[: self.model.config.num_hidden_layers]):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            
            # For the first decoder layer, we need to obfuscate the hidden_states
            hidden_states = self._custom_decoder_forward(
                i,
                hidden_states,
                position_embeddings_global=position_embeddings_global,
                position_embeddings_local=position_embeddings_local,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self._custom_layernorm(-1, "norm", hidden_states)
        
        transformer_outputs = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )

        hidden_states = transformer_outputs.last_hidden_state
        
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        outputs = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
        torch.cuda.synchronize()
        print(f"----------------------------------------------------")
        print(f"Foward time: {(time.perf_counter() - start_time) * 1000:.6f} milliseconds")
        if not self.simulate:
            sgx_compute_time, all_time = self.sgx.get_exe_time()
            print(f"Sgx compute time: {sgx_compute_time:.6f} milliseconds")
            print(f"Sgx compute + data transfer time: {all_time:.6f} milliseconds")
        print(f"----------------------------------------------------")
        return outputs
    
class CustomGPT2LMHeadModel(GPT2LMHeadModel):
    def __init__(self, config, obf_param, simulate=True):
        """Build the GPT-2 LM head model and prepare de-obfuscation state.

        Args:
            config: GPT-2 config forwarded to ``GPT2LMHeadModel``.
            obf_param: obfuscation parameters. In simulate mode it must map each
                of "q_proj", "k_proj", "v_proj", "o_proj", "c_fc", "c_proj" to
                a per-layer list of dicts with "permutation" and "blocks"
                entries; otherwise it is handed to the SGX instance as-is.
            simulate: restore on the host (True) or inside SGX (False).

        Raises:
            RuntimeError: if SGX rejects the obfuscation parameters.
        """
        super().__init__(config)
        self.simulate = simulate

        if not self.simulate:
            self.sgx = sgx_groupcover.get_sgx_instance()
            if not self.sgx.prepare_obf_params(obf_param):
                raise RuntimeError("Failed to prepare obfuscation parameters in SGX")
            return

        def inverted(proj_name):
            # Keep each layer's permutation and replace every block by its inverse,
            # which is what the host-side restore needs.
            return [
                {
                    "permutation": entry["permutation"],
                    "blocks": [torch.linalg.inv(mat) for mat in entry["blocks"]],
                }
                for entry in obf_param[proj_name]
            ]

        self.q_proj_param = inverted("q_proj")
        self.k_proj_param = inverted("k_proj")
        self.v_proj_param = inverted("v_proj")
        self.o_proj_param = inverted("o_proj")
        self.c_fc_param = inverted("c_fc")
        self.c_proj_param = inverted("c_proj")
    
    def prepare_norm_params(self):
        """Hand the layer-norm parameters of ``self.transformer`` to SGX.

        In simulate mode the layer norms run through the PyTorch modules and no
        SGX instance exists (``self.sgx`` is only created when
        ``simulate=False``), so there is nothing to prepare.

        Raises:
            RuntimeError: if SGX reports a failure.
        """
        if self.simulate:
            return
        if not self.sgx.prepare_norm_params(self.transformer):
            raise RuntimeError("Failed to prepare normalization parameters in SGX")
    
    def _custom_layernorm(self, layer_idx: int, norm_type: str, hidden_states):
        """Apply the GPT-2 layer norm named ``norm_type``.

        ``norm_type`` is "ln_1" or "ln_2" (the norms of block ``layer_idx``) or
        "ln_f" (the final norm; ``layer_idx`` is ignored in simulate mode).
        Outside simulate mode SGX performs the normalization.

        Raises:
            ValueError: in simulate mode, for an unknown ``norm_type``.
        """
        if not self.simulate:
            return self.sgx.norm(hidden_states, layer_idx, norm_type)

        if norm_type == "ln_1":
            norm_module = self.transformer.h[layer_idx].ln_1
        elif norm_type == "ln_2":
            norm_module = self.transformer.h[layer_idx].ln_2
        elif norm_type == "ln_f":
            norm_module = self.transformer.ln_f
        else:
            raise ValueError(f"No such norm: {norm_type}")
        return norm_module(hidden_states)
    
    def _custom_attn(
        self,
        layer_idx,
        hidden_states: torch.Tensor,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
        """Self-attention of GPT-2 block ``layer_idx`` with de-obfuscated projections.

        Mirrors the block's ``attn`` module, but the c_attn output (split into
        q/k/v) and the c_proj output are each passed through ``_restore``
        (simulate mode) or ``self.sgx.restore`` together with the matching
        bias slice before use.

        Args:
            layer_idx: index into ``self.transformer.h``.
            hidden_states: attention input; presumably (batch, seq, hidden) —
                TODO confirm.
            past_key_values: optional KV cache; its ``update`` is called with
                the new keys/values for self-attention.
            cache_position: forwarded to the cache update.
            attention_mask: forwarded to the attention kernel; when None and
                there is more than one query position, ``is_causal`` is set.
            head_mask: forwarded to the attention kernel.
            encoder_hidden_states: must be None; cross-attention raises
                NotImplementedError.
            encoder_attention_mask: unused.
            output_attentions: unused; the kernel's weights are always returned
                (they may be None depending on the kernel).
            **kwargs: forwarded to the attention interface.

        Returns:
            ``(attn_output, attn_weights)``, where ``attn_output`` is the
            restored c_proj output after ``resid_dropout``.
        """
        attn = self.transformer.h[layer_idx].attn
        
        is_cross_attention = encoder_hidden_states is not None
        # Select which cache (self- or cross-attention) this call reads/writes.
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(attn.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
                    curr_past_key_value = past_key_values.cross_attention_cache
                else:
                    curr_past_key_value = past_key_values.self_attention_cache
            else:
                curr_past_key_value = past_key_values

        if is_cross_attention:
            raise NotImplementedError("Cross attention is not supported yet")
        else:
            # c_attn produces q, k and v concatenated on the last dim; its bias
            # is split the same way so each restore call gets its own slice.
            query_states, key_states, value_states = attn.c_attn(hidden_states).split(attn.split_size, dim=2)
            query_bias, key_bias, value_bias = attn.c_attn.bias.data.split(attn.split_size, dim=0)
            if self.simulate:
                query_states = _restore(query_states, **self.q_proj_param[layer_idx], bias=query_bias)
                key_states = _restore(key_states, **self.k_proj_param[layer_idx], bias=key_bias)
                value_states = _restore(value_states, **self.v_proj_param[layer_idx], bias=value_bias)
            else:
                query_states = self.sgx.restore(
                    query_states, 
                    layer_idx,
                    "q_proj", 
                    query_bias
                )
                key_states = self.sgx.restore(
                    key_states, 
                    layer_idx,
                    "k_proj", 
                    key_bias
                )
                value_states = self.sgx.restore(
                    value_states, 
                    layer_idx,
                    "v_proj", 
                    value_bias
                )
            
            # Split the last dim into (heads, head_dim) and swap axes 1 and 2.
            shape_kv = (*key_states.shape[:-1], -1, attn.head_dim)
            key_states = key_states.view(shape_kv).transpose(1, 2)
            value_states = value_states.view(shape_kv).transpose(1, 2)

        shape_q = (*query_states.shape[:-1], -1, attn.head_dim)
        query_states = query_states.view(shape_q).transpose(1, 2)

        if (past_key_values is not None and not is_cross_attention) or (
            past_key_values is not None and is_cross_attention and not is_updated
        ):
            # save all key/value_layer to cache to be re-used for fast auto-regressive generation
            cache_position = cache_position if not is_cross_attention else None
            key_states, value_states = curr_past_key_value.update(
                key_states, value_states, attn.layer_idx, {"cache_position": cache_position}
            )
            # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
            if is_cross_attention:
                past_key_values.is_updated[attn.layer_idx] = True

        # Ask the kernel for causal masking only when no explicit mask was given.
        is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention

        using_eager = attn.config._attn_implementation == "eager"
        attention_interface: Callable = transformers.models.gpt2.modeling_gpt2.eager_attention_forward
        if attn.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[attn.config._attn_implementation]

        if using_eager and attn.reorder_and_upcast_attn:
            attn_output, attn_weights = attn._upcast_and_reordered_attn(
                query_states, key_states, value_states, attention_mask, head_mask
            )
        else:
            attn_output, attn_weights = attention_interface(
                attn,
                query_states,
                key_states,
                value_states,
                attention_mask,
                head_mask=head_mask,
                dropout=attn.attn_dropout.p if attn.training else 0.0,
                is_causal=is_causal,
                **kwargs,
            )
        
        # Merge the head axes back into one feature dim before the output projection.
        attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
        c_proj_output = attn.c_proj(attn_output)

        if self.simulate:
            attn_output = _restore(c_proj_output, **self.o_proj_param[layer_idx], bias=attn.c_proj.bias.data)
        else:
            attn_output = self.sgx.restore(
                c_proj_output, 
                layer_idx, 
                "o_proj", 
                attn.c_proj.bias if hasattr(attn.c_proj, "bias") else None
            )
        
        attn_output = attn.resid_dropout(attn_output)

        return attn_output, attn_weights
    
    def _custom_mlp(self, layer_idx, hidden_states):
        """GPT-2 MLP of block ``layer_idx`` with de-obfuscated projections.

        Computes ``dropout(c_proj(new_gelu(c_fc(x))))`` where both projection
        outputs are restored (with their biases) before use.

        Raises:
            NotImplementedError: if the block's activation is not
                ``NewGELUActivation``.
        """
        block_mlp = self.transformer.h[layer_idx].mlp
        unsupported = "Only NewGELUActivation is supported yet."
        projected = block_mlp.c_fc(hidden_states)

        if not self.simulate:
            activated = self.sgx.restore(
                projected,
                layer_idx,
                "c_fc",
                block_mlp.c_fc.bias
            )
            if not isinstance(block_mlp.act, NewGELUActivation):
                raise NotImplementedError(unsupported)
            activated = self.sgx.new_gelu_activation(activated)
        else:
            activated = _restore(projected, **self.c_fc_param[layer_idx], bias=block_mlp.c_fc.bias.data)
            if not isinstance(block_mlp.act, NewGELUActivation):
                raise NotImplementedError(unsupported)
            # Tanh-approximated GELU, written out the same way NewGELUActivation does.
            inner = math.sqrt(2.0 / math.pi) * (activated + 0.044715 * torch.pow(activated, 3.0))
            activated = 0.5 * activated * (1.0 + torch.tanh(inner))

        out_proj = block_mlp.c_proj(activated)
        if not self.simulate:
            restored = self.sgx.restore(
                out_proj,
                layer_idx,
                "c_proj",
                block_mlp.c_proj.bias
            )
        else:
            restored = _restore(out_proj, **self.c_proj_param[layer_idx], bias=block_mlp.c_proj.bias.data)

        return block_mlp.dropout(restored)
    
    def _custom_decoder_forward(
        self,
        decoder_layer_idx,
        hidden_states: torch.Tensor,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:
        """Run one GPT-2 block (pre-norm attention and MLP with residuals).

        Args:
            decoder_layer_idx: index into ``self.transformer.h``.
            hidden_states: block input.
            encoder_hidden_states: must be None; cross-attention is not supported.
            encoder_attention_mask: unused (cross-attention only).
            output_attentions: when True, the self-attention weights are
                appended to the returned tuple.
            Remaining arguments are forwarded to ``_custom_attn``.

        Returns:
            ``(hidden_states,)`` or ``(hidden_states, self_attn_weights)``.

        Raises:
            NotImplementedError: if ``encoder_hidden_states`` is given.
        """
        # Reject cross-attention before doing any work, so a failing call does
        # not run self-attention and update the KV cache first.
        if encoder_hidden_states is not None:
            raise NotImplementedError("Cross attention is not supported yet.")

        residual = hidden_states
        hidden_states = self._custom_layernorm(decoder_layer_idx, "ln_1", hidden_states)
        attn_output, self_attn_weights = self._custom_attn(
            decoder_layer_idx,
            hidden_states,
            past_key_values=past_key_values,
            cache_position=cache_position,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs,
        )

        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self._custom_layernorm(decoder_layer_idx, "ln_2", hidden_states)
        feed_forward_hidden_states = self._custom_mlp(decoder_layer_idx, hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
        
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
        if not self.simulate:
            self.sgx.reset_time()
        # Measure the time spent in the forward function
        start_time = time.perf_counter()
        
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        output_attentions = output_attentions if output_attentions is not None else self.transformer.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.transformer.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.transformer.config.use_cache
        return_dict = return_dict if return_dict is not None else self.transformer.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.transformer.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        if self.transformer.gradient_checkpointing and self.transformer.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
        if use_cache:
            if past_key_values is None:
                past_key_values = DynamicCache(config=self.transformer.config)
            elif isinstance(past_key_values, tuple):
                logger.warning_once(
                    "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
                    "You should pass an instance of `Cache` instead, e.g. "
                    "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
                )
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)

            if self.transformer.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache(config=self.transformer.config))

        if inputs_embeds is None:
            inputs_embeds = self.transformer.wte(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        position_embeds = self.transformer.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)

        # Attention mask.
        # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
        if attention_mask is not None and attention_mask.ndim < 4:
            attention_mask = attention_mask.view(batch_size, -1)

        causal_mask = create_causal_mask(
            config=self.transformer.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        _use_sdpa = self.transformer._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
        if self.transformer.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            if _use_sdpa:
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
                )
            elif self.transformer._attn_implementation != "flash_attention_2":
                encoder_attention_mask = self.transformer.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.transformer.get_head_mask(head_mask, self.transformer.config.n_layer)

        if token_type_ids is not None:
            token_type_embeds = self.transformer.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.transformer.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.transformer.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, block in enumerate(self.transformer.h):
            # Model parallel
            if self.transformer.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = self._custom_decoder_forward(
                i,
                hidden_states,
                past_key_values if not (self.transformer.gradient_checkpointing and self.transformer.training) else None,
                cache_position,
                causal_mask,
                head_mask[i],
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
                **kwargs,
            )

            hidden_states = outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)
                if self.transformer.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.transformer.model_parallel:
                for k, v in self.transformer.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.transformer.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self._custom_layernorm(-1, "ln_f", hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        past_key_values = past_key_values if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        transformer_outputs = BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # Flatten the tokens
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        outputs = CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
        torch.cuda.synchronize()
        print(f"----------------------------------------------------")
        print(f"Foward time: {(time.perf_counter() - start_time) * 1000:.6f} milliseconds")
        if not self.simulate:
            sgx_compute_time, all_time = self.sgx.get_exe_time()
            print(f"Sgx compute time: {sgx_compute_time:.6f} milliseconds")
            print(f"Sgx compute + data transfer time: {all_time:.6f} milliseconds")
        print(f"----------------------------------------------------")
        return outputs

    
def convert_to_custom_model(original_model, obf_param, simulate=True):
    """
    Convert a supported Hugging Face causal-LM into its custom counterpart.

    Supported source models and the custom class each one is converted to:
        Qwen3ForCausalLM  -> CustomQwen3ForCausalLM
        LlamaForCausalLM  -> CustomLlamaForCausalLM
        Gemma3ForCausalLM -> CustomGemma3ForCausalLM
        GPT2LMHeadModel   -> CustomGPT2LMHeadModel

    Args:
        original_model: Original model. Its ``config`` is reused to build the
            custom model and its ``state_dict()`` is loaded into it.
        obf_param: Obfuscation parameters, forwarded unchanged to the custom
            model's constructor.
        simulate: Whether to simulate obfuscation (True) or run in SGX (False).
            Default is True. When False, ``prepare_norm_params()`` is called on
            the converted model before it is returned.
    Returns:
        custom_model: Custom model instance with the original weights loaded,
            moved to the device of the original model's first parameter.
    Raises:
        ValueError: If ``original_model`` is not an instance of a supported type.
    """
    config = original_model.config

    # Ordered (source class, custom class) pairs. They are checked with
    # isinstance in this order, so subclasses of a supported type also match.
    model_map = (
        (Qwen3ForCausalLM, CustomQwen3ForCausalLM),
        (LlamaForCausalLM, CustomLlamaForCausalLM),
        (Gemma3ForCausalLM, CustomGemma3ForCausalLM),
        (GPT2LMHeadModel, CustomGPT2LMHeadModel),
    )
    for source_cls, custom_cls in model_map:
        if isinstance(original_model, source_cls):
            custom_model = custom_cls(config, obf_param, simulate)
            break
    else:
        raise ValueError(f"Unsupported model type: {type(original_model)}")

    # Copy the original weights into the freshly constructed custom model.
    # NOTE(review): load_state_dict copies values into the custom model's
    # existing parameters, and .to(device) below changes only the device.
    # The resulting dtype is therefore whatever the Custom* constructor
    # creates. Confirm this is intended when the original is e.g. fp16/bf16.
    custom_model.load_state_dict(original_model.state_dict())

    # Place the custom model on the same device as the original's first parameter.
    device = next(original_model.parameters()).device
    custom_model = custom_model.to(device)

    # The real (non-simulated) SGX path needs its norm parameters prepared up front.
    if not simulate:
        custom_model.prepare_norm_params()

    return custom_model