from typing import Callable, Optional, Union
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import logging
import math
import time
from transformers import (
    Qwen3ForSequenceClassification, 
    Qwen3ForCausalLM, 
    LlamaForCausalLM, 
    LlamaForSequenceClassification,
    GPT2LMHeadModel,
    GPT2ForSequenceClassification,
    Gemma3TextForSequenceClassification,
    Gemma3ForCausalLM
)
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs
from transformers.modeling_utils import (
    ALL_ATTENTION_FUNCTIONS
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    SequenceClassifierOutputWithPast,
    CausalLMOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions
)
from transformers.masking_utils import (
    create_causal_mask, 
    create_sliding_window_causal_mask
)
from transformers.models.gemma3.modeling_gemma3 import (
    _bidirectional_window_overlay,
    apply_rotary_pos_emb,
    eager_attention_forward
)
# Import SGX obfuscation module
from sgx_api import sgx_ours



# Custom formatter with colors for different log levels
class ColoredFormatter(logging.Formatter):
    """Formatter that colors the ``[LEVEL]`` tag with an ANSI escape sequence.

    Every record is rendered as ``<color>[LEVELNAME]<reset> <logger name>: <message>``.
    The color is looked up in ``COLORS`` by ``record.levelno``; a level without an
    entry falls back to the reset sequence (i.e. no color).
    """
    # ANSI color codes
    COLORS = {
        logging.DEBUG: '\033[94m',    # Blue
        logging.INFO: '\033[92m',     # Green
        logging.WARNING: '\033[93m',  # Yellow
        logging.ERROR: '\033[91m',    # Red
        logging.CRITICAL: '\033[95m', # Purple
        'RESET': '\033[0m'            # Reset color
    }

    # Delegate formatters keyed by (color, reset). The format string depends only on
    # these two values, so one delegate per pair is built on first use and then reused
    # instead of allocating a new Formatter for every record.
    _DELEGATES = {}

    def format(self, record):
        """Return ``record`` rendered with its level tag wrapped in the level's color."""
        reset = self.COLORS['RESET']
        # Get the color for this log level
        color = self.COLORS.get(record.levelno, reset)
        key = (color, reset)
        delegate = self._DELEGATES.get(key)
        if delegate is None:
            delegate = logging.Formatter(f'{color}[%(levelname)s]{reset} %(name)s: %(message)s')
            self._DELEGATES[key] = delegate
        return delegate.format(record)

# Console handler: INFO and above, rendered through the colored formatter.
formatter = ColoredFormatter()
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)

# Module logger: INFO and above, with the console handler attached.
logger = logging.getLogger(__name__)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)


def _house_holder_cal_col(input: torch.Tensor, vectors: list[torch.Tensor]) -> torch.Tensor:
    last_col = 0
    for vector in vectors:
        dim_v = vector.shape[0]
        vector = vector.view(-1, 1)
        x = input[:, :, last_col:last_col+dim_v]
        input[:, :, last_col:last_col+dim_v] = x - 2 * x @ vector @ vector.t()
        last_col = last_col+dim_v
    return input

def _permutate_col(input: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    return input[:, :, indices]

def _obfus_col(input: torch.Tensor, indices_list: list[torch.Tensor], vectors_list: list[list[torch.Tensor]]) -> torch.Tensor:
    r"""
    Apply the obfuscation rounds to the columns of ``input``, first round first.

    X = X @ H1 @ pi1 @ H2 @ pi2 @ ...

    Each round pairs one entry of ``vectors_list`` with one entry of ``indices_list``:
    the Householder step (`_house_holder_cal_col`) runs first, then the column
    permutation (`_permutate_col`). The two lists are zipped, so surplus entries in
    the longer list are ignored.

    Args:
        input: The tensor to obfuscate.
        indices_list: Per-round column index tensors for the permutation step.
        vectors_list: Per-round lists of Householder vectors.

    Returns:
        The obfuscated tensor.
    """
    obfuscated = input
    for round_vectors, round_indices in zip(vectors_list, indices_list):
        reflected = _house_holder_cal_col(obfuscated, round_vectors)
        obfuscated = _permutate_col(reflected, round_indices)
    return obfuscated

def _obfus_col_inv(input: torch.Tensor, indices_list: list[torch.Tensor], vectors_list: list[list[torch.Tensor]]) -> torch.Tensor:
    r"""
    Apply the obfuscation rounds in reverse: last round first, permutation before reflection.

    X = X @ pi_n @ H_n @ ... @ pi_1 @ H_1

    The index tensors are used exactly as given; they are not inverted here. A caller
    that wants to undo `_obfus_col` therefore has to supply the inverse permutations
    itself.

    Args:
        input: The tensor to transform.
        indices_list: Per-round column index tensors for the permutation step.
        vectors_list: Per-round lists of Householder vectors; indexed in step with
            ``indices_list``.

    Returns:
        The transformed tensor.
    """
    result = input
    for step in reversed(range(len(indices_list))):
        permuted = _permutate_col(result, indices_list[step])
        result = _house_holder_cal_col(permuted, vectors_list[step])
    return result


class CustomQwen3ForSequenceClassification(Qwen3ForSequenceClassification):
    r"""
    Modified based on the GenericForSequenceClassification and Qwen3Model class.

    If you cannot run this code in your own environment, you can find the Qwen3Model 
    code from your local transformers library and rewrite its forward function as follows.
    """
    def __init__(
        self,
        config,
        v_list0: list[list[torch.Tensor]],
        indices_list0: list[torch.Tensor],
        v_list: list[list[torch.Tensor]],
        indices_list: list[torch.Tensor],
        optimized_stage: int,
        simulate: bool,
    ):
        """
        Build the base classifier and store the obfuscation parameters.

        Args:
            config: Model configuration forwarded to the parent constructor.
            v_list0: Householder vectors used with ``indices_list0`` on the first layer's MLP output.
            indices_list0: Column indices used with ``v_list0``.
            v_list: Householder vectors used with ``indices_list`` on the first layer's output.
            indices_list: Column indices used with ``v_list``.
            optimized_stage: Value handed to ``sgx_ours.perform_obfuscation`` when not simulating.
            simulate: If True the obfuscation is computed in Python; otherwise it is delegated to SGX.
        """
        super().__init__(config)

        # Execution mode
        self.simulate = simulate
        self.optimized_stage = optimized_stage

        # Obfuscation parameters
        self.v_list0, self.indices_list0 = v_list0, indices_list0
        self.v_list, self.indices_list = v_list, indices_list
            
    
    def _custom_first_decoder_forward(
        self,
        decoder_layer,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        r"""
        Run the first decoder layer by hand and obfuscate its output.

        Output = (hidden_states + output_attn + output_mlp @ Pi0.T) @ Pi

        The layer's attention and MLP sub-blocks are invoked directly. In simulate mode
        the MLP output goes through `_obfus_col_inv` with (indices_list0, v_list0), is
        added to the post-attention residual, and the sum goes through `_obfus_col` with
        (indices_list, v_list). Otherwise the MLP output and the residual are handed to
        `sgx_ours.perform_obfuscation` and the wall-clock time of that call is printed.
        """
        # Self attention, followed by its residual connection
        normed_input = decoder_layer.input_layernorm(hidden_states)
        attn_output, _ = decoder_layer.self_attn(
            hidden_states=normed_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        attn_residual = hidden_states + attn_output

        # Fully connected; the residual add is left to the obfuscation step below
        mlp_output = decoder_layer.mlp(decoder_layer.post_attention_layernorm(attn_residual))

        if not self.simulate:
            # Synchronize before starting the clock so the data-transfer timing is accurate
            torch.cuda.synchronize()

            # Use SGX for obfuscation
            start_time = time.perf_counter()
            obfuscated = sgx_ours.perform_obfuscation(mlp_output, attn_residual, self.optimized_stage)
            end_time = time.perf_counter()
            print(f"----------------------------------------------------\nSgx cost(input_obf+data_transfer): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")
            return obfuscated

        # residual + output_mlp @ Pi0.T
        combined = attn_residual + _obfus_col_inv(mlp_output, self.indices_list0, self.v_list0)
        return _obfus_col(combined, self.indices_list, self.v_list)
        
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequenceClassifierOutputWithPast:
        r"""
        Classify a sequence, sending the first decoder layer through the obfuscation path.

        The decoder stack is driven from here: layer 0 goes through
        `_custom_first_decoder_forward`, every other layer is called normally. The
        elapsed wall-clock time of the whole call is printed before returning.

        Args:
            input_ids: Token ids. Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
            attention_mask: Attention mask, or a dict that already maps each layer's
                ``attention_type`` to a prepared mask.
            position_ids: Position ids; derived from ``cache_position`` when omitted.
            past_key_values: Existing cache; a ``DynamicCache`` is created when ``use_cache``
                is truthy and none is given.
            inputs_embeds: Pre-computed embeddings used instead of ``input_ids``.
            labels: If given, a loss is computed through ``self.loss_function``.
            use_cache: Whether to create and return the key/value cache.
            cache_position: Cache positions of the current tokens; derived when omitted.

        Returns:
            ``SequenceClassifierOutputWithPast`` whose ``logits`` are the logits of the
            last non-padding token of each row. ``hidden_states`` and ``attentions``
            are not populated.

        Raises:
            ValueError: If not exactly one of ``input_ids`` / ``inputs_embeds`` is given, or
                if the batch size is not 1 while ``config.pad_token_id`` is None.
        """
        # Measure the time spent in the forward function
        start_time = time.perf_counter()

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.model.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.model.rotary_emb(hidden_states, position_ids)

        for i, decoder_layer in enumerate(self.model.layers[: self.model.config.num_hidden_layers]):
            # For the first decoder layer, we need to obfuscate the hidden_states
            if i == 0:
                hidden_states = self._custom_first_decoder_forward(
                    decoder_layer,
                    hidden_states,
                    attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
            else:
                hidden_states = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

        hidden_states = self.model.norm(hidden_states)
        transformer_outputs = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )

        hidden_states = transformer_outputs.last_hidden_state

        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        output = SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
        # Wait for queued GPU work so the printed time covers it. The call raises on
        # hosts without CUDA, so it is skipped there (e.g. simulate mode on CPU).
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print(f"----------------------------------------------------\nForward cost: {(time.perf_counter() - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")

        return output

class CustomQwen3ForCausalLM(Qwen3ForCausalLM):
    r"""
    Modified based on the Qwen3ForCausalLM and Qwen3Model class.

    If you cannot run this code in your own environment, you can find the Qwen3Model 
    code from your local transformers library and rewrite its forward function as follows.
    """
    def __init__(
        self, 
        config, 
        v_list0: list[list[torch.Tensor]], 
        indices_list0: list[torch.Tensor], 
        v_list: list[list[torch.Tensor]], 
        indices_list: list[torch.Tensor], 
        optimized_stage: int, 
        simulate: bool
    ):
        """
        Build the base causal LM and store the obfuscation parameters.

        Args:
            config: Model configuration forwarded to the parent constructor.
            v_list0: Householder vectors used with ``indices_list0`` on the first layer's MLP output.
            indices_list0: Column indices used with ``v_list0``.
            v_list: Householder vectors used with ``indices_list`` on the first layer's output.
            indices_list: Column indices used with ``v_list``.
            optimized_stage: Value handed to ``sgx_ours.perform_obfuscation`` when not simulating.
            simulate: If True the obfuscation is computed in Python; otherwise it is delegated to SGX.
        """
        super().__init__(config)

        # Execution mode
        self.simulate = simulate
        self.optimized_stage = optimized_stage

        # Obfuscation parameters
        self.v_list0, self.indices_list0 = v_list0, indices_list0
        self.v_list, self.indices_list = v_list, indices_list

        # One-time-pad state; filled by `prepare_otp_params` in simulate mode
        self.otp = self.otp_logits = None

        # The inverse permutations are only needed (and only defined) when simulating
        if self.simulate:
            self.indices_list_inv = [torch.argsort(permutation) for permutation in self.indices_list]
    
    def _custom_first_decoder_forward(
        self,
        decoder_layer,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        r"""
        Run the first decoder layer by hand and obfuscate its output.

        Output = (hidden_states + output_attn + output_mlp @ Pi0.T) @ Pi

        The layer's attention and MLP sub-blocks are invoked directly. In simulate mode
        the MLP output goes through `_obfus_col_inv` with (indices_list0, v_list0), is
        added to the post-attention residual, and the sum goes through `_obfus_col` with
        (indices_list, v_list). Otherwise the MLP output and the residual are handed to
        `sgx_ours.perform_obfuscation` and the wall-clock time of that call is printed.
        """
        # Self attention, followed by its residual connection
        normed_input = decoder_layer.input_layernorm(hidden_states)
        attn_output, _ = decoder_layer.self_attn(
            hidden_states=normed_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        attn_residual = hidden_states + attn_output

        # Fully connected; the residual add is left to the obfuscation step below
        mlp_output = decoder_layer.mlp(decoder_layer.post_attention_layernorm(attn_residual))

        if not self.simulate:
            # Synchronize before starting the clock so the data-transfer timing is accurate
            torch.cuda.synchronize()

            # Use SGX for obfuscation
            start_time = time.perf_counter()
            obfuscated = sgx_ours.perform_obfuscation(mlp_output, attn_residual, self.optimized_stage)
            end_time = time.perf_counter()
            print(f"----------------------------------------------------\nSgx cost(input_obf+data_transfer): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")
            return obfuscated

        # residual + output_mlp @ Pi0.T
        combined = attn_residual + _obfus_col_inv(mlp_output, self.indices_list0, self.v_list0)
        return _obfus_col(combined, self.indices_list, self.v_list)
    
    def _recover_and_otp(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Undo the column obfuscation on the normed hidden states and add the one-time pad.

        The final norm's weight is divided out first and multiplied back in at the end.
        In between, simulate mode applies `_obfus_col_inv` with the inverse permutations
        and adds ``self.otp``; otherwise the tensor is handed to `sgx_ours.perform_otp`
        and the wall-clock time of that call is printed.

        hidden_states: (batch_size, logits_to_keep, hidden_size)
        """
        norm_weight = self.model.norm.weight.data
        unscaled = hidden_states / norm_weight

        if self.simulate:
            recovered = _obfus_col_inv(unscaled, self.indices_list_inv, self.v_list)
            padded = recovered + self.otp
        else:
            # Synchronize before starting the clock so the data-transfer timing is accurate
            torch.cuda.synchronize()

            start_time = time.perf_counter()
            padded = sgx_ours.perform_otp(unscaled)
            end_time = time.perf_counter()
            print(f"----------------------------------------------------\nSgx cost(otp+data_transfer): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")

        return padded * norm_weight
    
    def _recover_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Recover the logits from the opt hidden states.
        """
        if self.simulate:
            hidden_states = hidden_states - self.otp_logits
            # replace the max value with 1, others with 0
            _, indices = torch.max(hidden_states, dim=-1, keepdim=True)
            logits = torch.zeros_like(hidden_states, device=hidden_states.device).scatter_(-1, indices, 1.0)
        else:
            # Fix inaccurate data transfer timing measurement
            torch.cuda.synchronize()
            
            start_time = time.perf_counter()
            logits = sgx_ours.perform_logits_recover(hidden_states)
            end_time = time.perf_counter()
            print(f"----------------------------------------------------\nSgx cost(logits_recover+data_transfer): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")
        
        return logits
    
    def prepare_otp_params(self, batch_size: int, logits_to_keep: int=1):
        """
        Prepare the OTP params in SGX.
        
        Args:
            batch_size (int): The batch size.
            logits_to_keep (int, optional): The number of logits to keep. Defaults to 1.
        """
        # Initialize otp with random values
        otp = torch.randn((batch_size, logits_to_keep, self.config.hidden_size), device=self.device)
        otp_logits = self.lm_head(otp)
        otp = otp / self.model.norm.weight.data
        
        if self.simulate:
            self.otp = otp
            self.otp_logits = otp_logits
        else:
            if not sgx_ours.prepare_otp_params(otp, otp_logits):
                raise RuntimeError("Failed to initialize SGX OTP params")
    
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs]
    ) -> CausalLMOutputWithPast:
        """ 
        Run the causal LM with an obfuscated first decoder layer and OTP-masked logits.

        Layer 0 goes through `_custom_first_decoder_forward`; every other layer is called
        normally. The kept positions of the final hidden states then pass through
        `_recover_and_otp`, ``lm_head`` and `_recover_logits`. The elapsed wall-clock
        time (excluding OTP preparation) is printed before returning.

        NOTE: In practice, you should implement an OTP manager inside SGX to generate and 
        manage OTP keys in advance, instead of generating them before each inference.
        
        When measuring performance, we omit this overhead because it can be fully implemented within SGX.

        Args:
            input_ids: Token ids. Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
            attention_mask: Attention mask, or a dict that already maps each layer's
                ``attention_type`` to a prepared mask.
            position_ids: Position ids; derived from ``cache_position`` when omitted.
            past_key_values: Existing cache; a ``DynamicCache`` is created when ``use_cache``
                is truthy and none is given.
            inputs_embeds: Pre-computed embeddings used instead of ``input_ids``.
            labels: If given, a loss is computed through ``self.loss_function``.
            use_cache: Whether to create and return the key/value cache.
            cache_position: Cache positions of the current tokens; derived when omitted.
            logits_to_keep: An int keeps the last ``logits_to_keep`` positions (0 keeps all);
                a tensor is used directly as an index along the sequence dimension.

        Returns:
            ``CausalLMOutputWithPast`` carrying the recovered logits. ``hidden_states``
            and ``attentions`` are not populated.

        Raises:
            ValueError: If not exactly one of ``input_ids`` / ``inputs_embeds`` is given.
        """
        # Validate before anything touches the inputs: the OTP sizing below would
        # otherwise fail with an AttributeError when only `inputs_embeds` is supplied.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        # Size the pad from whichever input was supplied; both carry batch and sequence
        # in their first two dimensions.
        shape_source = input_ids if input_ids is not None else inputs_embeds
        self.prepare_otp_params(shape_source.shape[0], shape_source[:, slice_indices].shape[1])

        # Measure the time spent in the forward function
        start_time = time.perf_counter()

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.model.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.model.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.model.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.model.rotary_emb(hidden_states, position_ids)

        for i, decoder_layer in enumerate(self.model.layers[: self.model.config.num_hidden_layers]):
            # For the first decoder layer, we need to obfuscate the hidden_states
            if i == 0:
                hidden_states = self._custom_first_decoder_forward(
                    decoder_layer,
                    hidden_states,
                    attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
            else:
                hidden_states = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

        hidden_states = self.model.norm(hidden_states)
        transformer_outputs = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )

        hidden_states = transformer_outputs.last_hidden_state

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        hidden_states = hidden_states[:, slice_indices, :]
        hidden_states = self._recover_and_otp(hidden_states)
        hidden_states = self.lm_head(hidden_states)
        logits = self._recover_logits(hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        output = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
        # Wait for queued GPU work so the printed time covers it. The call raises on
        # hosts without CUDA, so it is skipped there (e.g. simulate mode on CPU).
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print(f"----------------------------------------------------\nForward cost: {(time.perf_counter() - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")
        return output

class CustomLlamaForSequenceClassification(LlamaForSequenceClassification):
    r"""
    Modified based on the GenericForSequenceClassification and LlamaModel class.

    If you cannot run this code in your own environment, you can find the LlamaModel 
    code from your local transformers library and rewrite its forward function as follows.
    """
    def __init__(
        self,
        config,
        v_list0: list[list[torch.Tensor]],
        indices_list0: list[torch.Tensor],
        v_list: list[list[torch.Tensor]],
        indices_list: list[torch.Tensor],
        optimized_stage: int,
        simulate: bool,
    ):
        """
        Build the base classifier and store the obfuscation parameters.

        Args:
            config: Model configuration forwarded to the parent constructor.
            v_list0: Householder vectors used with ``indices_list0`` on the first layer's MLP output.
            indices_list0: Column indices used with ``v_list0``.
            v_list: Householder vectors used with ``indices_list`` on the first layer's output.
            indices_list: Column indices used with ``v_list``.
            optimized_stage: Value handed to ``sgx_ours.perform_obfuscation`` when not simulating.
            simulate: If True the obfuscation is computed in Python; otherwise it is delegated to SGX.
        """
        super().__init__(config)

        # Execution mode
        self.simulate = simulate
        self.optimized_stage = optimized_stage

        # Obfuscation parameters
        self.v_list0, self.indices_list0 = v_list0, indices_list0
        self.v_list, self.indices_list = v_list, indices_list
    
    def _custom_first_decoder_forward(
        self,
        decoder_layer,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        r"""
        Run the first decoder layer by hand and obfuscate its output.

        Output = (hidden_states + output_attn + output_mlp @ Pi0.T) @ Pi

        The layer's attention and MLP sub-blocks are invoked directly. In simulate mode
        the MLP output goes through `_obfus_col_inv` with (indices_list0, v_list0), is
        added to the post-attention residual, and the sum goes through `_obfus_col` with
        (indices_list, v_list). Otherwise the MLP output and the residual are handed to
        `sgx_ours.perform_obfuscation` and the wall-clock time of that call is printed.
        """
        # Self attention, followed by its residual connection
        normed_input = decoder_layer.input_layernorm(hidden_states)
        attn_output, _ = decoder_layer.self_attn(
            hidden_states=normed_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        attn_residual = hidden_states + attn_output

        # Fully connected; the residual add is left to the obfuscation step below
        mlp_output = decoder_layer.mlp(decoder_layer.post_attention_layernorm(attn_residual))

        if not self.simulate:
            # Synchronize before starting the clock so the data-transfer timing is accurate
            torch.cuda.synchronize()

            # Use SGX for obfuscation
            start_time = time.perf_counter()
            obfuscated = sgx_ours.perform_obfuscation(mlp_output, attn_residual, self.optimized_stage)
            end_time = time.perf_counter()
            print(f"----------------------------------------------------\nSgx cost(input_obf+data_transfer): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")
            return obfuscated

        # residual + output_mlp @ Pi0.T
        combined = attn_residual + _obfus_col_inv(mlp_output, self.indices_list0, self.v_list0)
        return _obfus_col(combined, self.indices_list, self.v_list)
        
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequenceClassifierOutputWithPast:
        r"""
        Classify a sequence, sending the first decoder layer through the obfuscation path.

        The decoder stack is driven from here: layer 0 goes through
        `_custom_first_decoder_forward`, every other layer is called normally. The
        elapsed wall-clock time of the whole call is printed before returning.

        Args:
            input_ids: Token ids. Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
            attention_mask: Attention mask passed to ``create_causal_mask``.
            position_ids: Position ids; derived from ``cache_position`` when omitted.
            past_key_values: Existing cache; a ``DynamicCache`` is created when ``use_cache``
                is truthy and none is given.
            inputs_embeds: Pre-computed embeddings used instead of ``input_ids``.
            labels: If given, a loss is computed through ``self.loss_function``.
            use_cache: Controls cache creation and whether the cache is returned; it is
                not forwarded to the decoder layers here.
            cache_position: Cache positions of the current tokens; derived when omitted.

        Returns:
            ``SequenceClassifierOutputWithPast`` whose ``logits`` are the logits of the
            last non-padding token of each row. ``hidden_states`` and ``attentions``
            are not populated.

        Raises:
            ValueError: If not exactly one of ``input_ids`` / ``inputs_embeds`` is given, or
                if the batch size is not 1 while ``config.pad_token_id`` is None.
        """
        # Measure the time spent in the forward function
        start_time = time.perf_counter()

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # Position embeddings are computed once and shared across the decoder layers
        position_embeddings = self.model.rotary_emb(hidden_states, position_ids)

        for i, decoder_layer in enumerate(self.model.layers[: self.model.config.num_hidden_layers]):
            # For the first decoder layer, we need to obfuscate the hidden_states
            if i == 0:
                hidden_states = self._custom_first_decoder_forward(
                    decoder_layer,
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
            else:
                hidden_states = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

        hidden_states = self.model.norm(hidden_states)
        transformer_outputs = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )

        hidden_states = transformer_outputs.last_hidden_state

        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        output = SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
        # Wait for queued GPU work so the printed time covers it. The call raises on
        # hosts without CUDA, so it is skipped there (e.g. simulate mode on CPU).
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print(f"----------------------------------------------------\nForward cost: {(time.perf_counter() - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")

        return output

class CustomLlamaForCausalLM(LlamaForCausalLM):
    r"""
    Modified based on the LlamaForCausalLM and LlamaModel class.

    The first decoder layer is run through `_custom_first_decoder_forward`, which
    obfuscates that layer's output. The hidden states kept for the LM head are then
    de-obfuscated and masked with a one-time pad (OTP), and the pad's contribution is
    removed again from the LM-head output. With `simulate=True` all of these steps run
    locally as torch ops; otherwise they are delegated to the `sgx_ours` module.

    If you cannot run this code in your own environment, you can find the LlamaModel
    code from your local transformers library and rewrite its forward function as follows.
    """
    def __init__(
        self,
        config,
        v_list0: list[list[torch.Tensor]],
        indices_list0: list[torch.Tensor],
        v_list: list[list[torch.Tensor]],
        indices_list: list[torch.Tensor],
        optimized_stage: int,
        simulate: bool
    ):
        """
        Build the model and store the obfuscation material.

        Args:
            config: Model configuration, passed straight to `LlamaForCausalLM`.
            v_list0: Passed with `indices_list0` to `_obfus_col_inv` on the first
                layer's MLP output (simulate mode only).
            indices_list0: See `v_list0`.
            v_list: Passed with `indices_list` to `_obfus_col` on the first layer's
                output, and with the inverted indices to `_obfus_col_inv` in
                `_recover_and_otp` (simulate mode only).
            indices_list: See `v_list`.
            optimized_stage: Forwarded unchanged to `sgx_ours.perform_obfuscation`.
            simulate: If True, emulate the SGX steps locally instead of calling `sgx_ours`.
        """
        super().__init__(config)

        self.v_list0 = v_list0
        self.indices_list0 = indices_list0
        self.v_list = v_list
        self.indices_list = indices_list
        # In simulate mode these hold the current OTP and its LM-head projection
        # (filled in by `prepare_otp_params`).
        self.otp = None
        self.otp_logits = None
        self.optimized_stage = optimized_stage
        self.simulate = simulate

        if self.simulate:
            # argsort of each index tensor; used by `_recover_and_otp`.
            self.indices_list_inv = [torch.argsort(indices) for indices in self.indices_list]

    @staticmethod
    def _sync_cuda() -> None:
        """Block until queued CUDA work has finished; do nothing on hosts without CUDA."""
        # An unguarded torch.cuda.synchronize() raises when CUDA is unavailable,
        # which would make CPU-only (e.g. simulate-mode) runs impossible.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def _timed_sgx_call(self, label: str, sgx_fn: Callable, *args):
        """Call `sgx_fn(*args)`, print its wall-clock cost under `label`, and return its result."""
        # Fix inaccurate data transfer timing measurement: drain pending GPU work
        # first so it is not attributed to the SGX call.
        self._sync_cuda()

        start_time = time.perf_counter()
        result = sgx_fn(*args)
        end_time = time.perf_counter()
        print(f"----------------------------------------------------\nSgx cost({label}): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")
        return result

    def _custom_first_decoder_forward(
        self,
        decoder_layer,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        r"""
        Customized forward function for the first decoder layer.

        Output = (hidden_states + output_attn + output_mlp @ Pi0.T) @ Pi

        Args:
            decoder_layer: The first decoder layer of `self.model.layers`.
            hidden_states: Input to the layer (the token embeddings).
            attention_mask, position_ids, past_key_values, use_cache, cache_position,
            position_embeddings, **kwargs: Forwarded to `decoder_layer.self_attn`.

        Returns:
            The obfuscated output of the first decoder layer.
        """
        # Self Attention
        attn_residual = hidden_states
        hidden_states = decoder_layer.input_layernorm(hidden_states)
        hidden_states, _ = decoder_layer.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = attn_residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = decoder_layer.post_attention_layernorm(hidden_states)
        hidden_states = decoder_layer.mlp(hidden_states)

        if self.simulate:
            # residual + output_mlp @ Pi0.T
            hidden_states = residual + _obfus_col_inv(hidden_states, self.indices_list0, self.v_list0)
            return _obfus_col(hidden_states, self.indices_list, self.v_list)

        # Use SGX for obfuscation; the residual add happens on the SGX side.
        return self._timed_sgx_call(
            "input_obf+data_transfer",
            sgx_ours.perform_obfuscation,
            hidden_states,
            residual,
            self.optimized_stage,
        )

    def _recover_and_otp(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Recover the hidden states from the obfuscated states and do otp.

        The final-norm weight is divided out before the recovery/OTP step and
        multiplied back in afterwards.

        hidden_states: (batch_size, logits_to_keep, hidden_size)
        """
        norm_weight = self.model.norm.weight.data
        hidden_states = hidden_states / norm_weight

        if self.simulate:
            hidden_states = _obfus_col_inv(hidden_states, self.indices_list_inv, self.v_list)
            hidden_states = hidden_states + self.otp
        else:
            hidden_states = self._timed_sgx_call("otp+data_transfer", sgx_ours.perform_otp, hidden_states)

        return hidden_states * norm_weight

    def _recover_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Recover the logits from the OTP-masked LM-head output.

        In simulate mode the stored `otp_logits` are subtracted and the result is
        turned into a one-hot tensor marking the arg-max along the last dimension.
        Otherwise the work is delegated to `sgx_ours.perform_logits_recover`.
        """
        if not self.simulate:
            return self._timed_sgx_call(
                "logits_recover+data_transfer", sgx_ours.perform_logits_recover, hidden_states
            )

        hidden_states = hidden_states - self.otp_logits
        # replace the max value with 1, others with 0
        _, indices = torch.max(hidden_states, dim=-1, keepdim=True)
        return torch.zeros_like(hidden_states).scatter_(-1, indices, 1.0)

    def prepare_otp_params(self, batch_size: int, logits_to_keep: int=1):
        """
        Prepare the OTP params in SGX.

        A random pad of shape (batch_size, logits_to_keep, hidden_size) is drawn and
        projected through `lm_head`; the pad itself is then divided by the final-norm
        weight. In simulate mode both tensors are kept on the model, otherwise they
        are handed to `sgx_ours.prepare_otp_params`.

        Args:
            batch_size (int): The batch size.
            logits_to_keep (int, optional): The number of logits to keep. Defaults to 1.

        Raises:
            RuntimeError: If `sgx_ours.prepare_otp_params` reports failure.
        """
        # Initialize otp with random values
        otp = torch.randn((batch_size, logits_to_keep, self.config.hidden_size), device=self.device)
        otp_logits = self.lm_head(otp)
        otp = otp / self.model.norm.weight.data

        if self.simulate:
            self.otp = otp
            self.otp_logits = otp_logits
        elif not sgx_ours.prepare_otp_params(otp, otp_logits):
            raise RuntimeError("Failed to initialize SGX OTP params")

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs]
    ) -> CausalLMOutputWithPast:
        """
        Run the obfuscated forward pass and print its wall-clock cost.

        NOTE: In practice, you should implement an OTP manager inside SGX to generate and
        manage OTP keys in advance, instead of generating them before each inference.

        When measuring performance, we omit this overhead because it can be fully implemented within SGX.

        Returns:
            CausalLMOutputWithPast whose `logits` come from `_recover_logits`
            (one-hot arg-max tensors in simulate mode).

        Raises:
            ValueError: If not exactly one of `input_ids` / `inputs_embeds` is given.
        """
        # Validate before touching either input: the OTP below is sized from
        # whichever of the two was actually supplied.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # Only compute necessary logits (an int of 0 keeps every position).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        model_input = input_ids if input_ids is not None else inputs_embeds
        self.prepare_otp_params(model_input.shape[0], model_input[:, slice_indices].shape[1])

        # Measure the time spent in the forward function (OTP preparation excluded, see NOTE)
        start_time = time.perf_counter()

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.model.rotary_emb(hidden_states, position_ids)

        for i, decoder_layer in enumerate(self.model.layers[: self.model.config.num_hidden_layers]):
            # For the first decoder layer, we need to obfuscate the hidden_states
            if i == 0:
                hidden_states = self._custom_first_decoder_forward(
                    decoder_layer,
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
            else:
                hidden_states = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

        hidden_states = self.model.norm(hidden_states)
        transformer_outputs = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )

        # De-obfuscate + OTP-mask the kept positions, project, then strip the OTP again.
        hidden_states = transformer_outputs.last_hidden_state[:, slice_indices, :]
        hidden_states = self._recover_and_otp(hidden_states)
        hidden_states = self.lm_head(hidden_states)
        logits = self._recover_logits(hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        output = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
        # Make sure all GPU work is included in the reported time.
        self._sync_cuda()
        print(f"----------------------------------------------------\nForward cost: {(time.perf_counter() - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")
        return output

class CustomGemma3TextForSequenceClassification(Gemma3TextForSequenceClassification):
    r"""
    Modified based on the GenericForSequenceClassification and Gemma3TextModel class.

    The first decoder layer's output is obfuscated (locally when `simulate` is True,
    through `sgx_ours` otherwise). Wherever this class replaces a post-layernorm call,
    the sub-layer output is instead scaled by an inverse RMS computed from a separate
    "branch" linear projection of the same input (`o_proj_branchs` / `down_proj_branchs`).

    If you cannot run this code in your own environment, you can find the Gemma3TextModel
    code from your local transformers library and rewrite its forward function as follows.
    """
    def __init__(
        self,
        config,
        v_list0: list[list[torch.Tensor]],
        indices_list0: list[torch.Tensor],
        v_list: list[list[torch.Tensor]],
        indices_list: list[torch.Tensor],
        optimized_stage: int,
        simulate: bool,
    ):
        """
        Build the model, store the obfuscation material and create the branch projections.

        Args:
            config: Model configuration, passed straight to the parent class.
            v_list0: Passed with `indices_list0` to `_obfus_col_inv` on the first
                layer's MLP output (simulate mode only).
            indices_list0: See `v_list0`.
            v_list: Passed with `indices_list` to `_obfus_col` on the first layer's
                output (simulate mode only).
            indices_list: See `v_list`.
            optimized_stage: Forwarded unchanged to `sgx_ours.perform_obfuscation`.
            simulate: If True, emulate the SGX obfuscation locally.
        """
        super().__init__(config)

        self.v_list0 = v_list0
        self.indices_list0 = indices_list0
        self.v_list = v_list
        self.indices_list = indices_list
        self.optimized_stage = optimized_stage
        self.simulate = simulate

        # One branch per layer with the same shape as that layer's o_proj. Index 0 is a
        # None placeholder: the first layer still uses its own post_attention_layernorm.
        self.o_proj_branchs = nn.ModuleList([
            None if i == 0 else self._linear_like(self.model.layers[i].self_attn.o_proj)
            for i in range(self.config.num_hidden_layers)
        ])

        # One branch per layer with the same shape as that layer's mlp.down_proj.
        self.down_proj_branchs = nn.ModuleList([
            self._linear_like(self.model.layers[i].mlp.down_proj)
            for i in range(self.config.num_hidden_layers)
        ])

    @staticmethod
    def _linear_like(linear: nn.Linear) -> nn.Linear:
        """Create a freshly initialised `nn.Linear` with the same in/out features and bias setting."""
        return nn.Linear(linear.in_features, linear.out_features, bias=linear.bias is not None)

    @staticmethod
    def _sync_cuda() -> None:
        """Block until queued CUDA work has finished; do nothing on hosts without CUDA."""
        # An unguarded torch.cuda.synchronize() raises when CUDA is unavailable.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @staticmethod
    def _rms_rescale(
        branch: nn.Module,
        branch_input: torch.Tensor,
        output: torch.Tensor,
        eps: float,
    ) -> torch.Tensor:
        """
        Scale `output` by the inverse RMS of `branch(branch_input)`.

        The inverse RMS is taken over the last dimension, with `eps` added to the
        mean square for numerical stability.
        """
        rms_input = branch(branch_input)
        inv_rms = torch.rsqrt(rms_input.pow(2).mean(-1, keepdim=True) + eps)
        return output * inv_rms

    def _mlp(self, mlp, x):
        """Run the gated MLP and return `(down_proj output, down_proj input)`."""
        down_proj_input = mlp.act_fn(mlp.gate_proj(x)) * mlp.up_proj(x)
        return mlp.down_proj(down_proj_input), down_proj_input

    def _attn(
        self,
        attn_layer,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """
        Run self-attention with `attn_layer`'s modules and also expose the `o_proj` input.

        Args:
            attn_layer: The attention module whose projections, norms and settings are used.
            hidden_states: Layer-normed input; all leading dims are treated as batch/sequence.
            position_embeddings: `(cos, sin)` pair applied as rotary embeddings.
            attention_mask: Mask handed to the attention kernel.
            past_key_values: Optional cache, updated in place for `attn_layer.layer_idx`.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the attention kernel.

        Returns:
            `(attn_output, attn_weights, o_proj_input)`.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, attn_layer.head_dim)

        query_states = attn_layer.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = attn_layer.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = attn_layer.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        query_states = attn_layer.q_norm(query_states)
        key_states = attn_layer.k_norm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, attn_layer.layer_idx, cache_kwargs)

        # Fall back to the eager kernel unless the config selects another implementation.
        attention_interface: Callable = eager_attention_forward
        if attn_layer.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[attn_layer.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            attn_layer,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=attn_layer.attention_dropout if attn_layer.training else 0.0,
            scaling=attn_layer.scaling,
            sliding_window=attn_layer.sliding_window,
            **kwargs,
        )

        # Keep the merged-heads tensor: callers feed it to the o_proj branch.
        o_proj_input = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = attn_layer.o_proj(o_proj_input)
        return attn_output, attn_weights, o_proj_input

    def _custom_first_decoder_forward(
        self,
        decoder_layer,
        hidden_states: torch.Tensor,
        position_embeddings_global: tuple[torch.Tensor, torch.Tensor],
        position_embeddings_local: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, ...]:
        r"""
        Customized forward function for the first decoder layer.

        Output = (hidden_states + output_attn + output_mlp @ Pi0.T) @ Pi

        The attention half uses the layer's own modules unchanged. In the MLP half the
        `post_feedforward_layernorm` module is not called; its `eps` is used for an
        inverse-RMS rescale driven by `down_proj_branchs[0]`.

        Returns:
            `(hidden_states,)`, plus the attention weights when `output_attentions` is truthy.
        """
        residual = hidden_states
        hidden_states = decoder_layer.input_layernorm(hidden_states)

        # apply global RoPE to non-sliding layer only
        position_embeddings = (
            position_embeddings_local if decoder_layer.self_attn.is_sliding else position_embeddings_global
        )

        hidden_states, self_attn_weights = decoder_layer.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = decoder_layer.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = decoder_layer.pre_feedforward_layernorm(hidden_states)

        # get the inputs of down_proj
        hidden_states, down_proj_input = self._mlp(decoder_layer.mlp, hidden_states)

        # calculate RMS(down_proj_input @ W_down_proj @ Q)
        hidden_states = self._rms_rescale(
            self.down_proj_branchs[0],
            down_proj_input,
            hidden_states,
            decoder_layer.post_feedforward_layernorm.eps,
        )

        if self.simulate:
            # residual + output_mlp @ Pi0.T
            hidden_states = residual + _obfus_col_inv(hidden_states, self.indices_list0, self.v_list0)
            hidden_states = _obfus_col(hidden_states, self.indices_list, self.v_list)
        else:
            # Fix inaccurate data transfer timing measurement
            self._sync_cuda()

            # Use SGX for obfuscation
            start_time = time.perf_counter()
            hidden_states = sgx_ours.perform_obfuscation(hidden_states, residual, self.optimized_stage)
            end_time = time.perf_counter()
            print(f"----------------------------------------------------\nSgx cost(input_obf+data_transfer): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)

    def _custom_remaining_decoder_forward(
        self,
        decoder_layer_idx,
        hidden_states: torch.Tensor,
        position_embeddings_global: tuple[torch.Tensor, torch.Tensor],
        position_embeddings_local: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, ...]:
        """
        Forward function for every decoder layer after the first.

        Neither `post_attention_layernorm` nor `post_feedforward_layernorm` is called;
        each sub-layer output is instead rescaled with the inverse RMS of the matching
        branch projection, using the respective layernorm's `eps`.

        Args:
            decoder_layer_idx: Index into `self.model.layers` and into both branch lists.

        Returns:
            `(hidden_states,)`, plus the attention weights when `output_attentions` is truthy.
        """
        decoder_layer = self.model.layers[decoder_layer_idx]

        residual = hidden_states
        hidden_states = decoder_layer.input_layernorm(hidden_states)

        # apply global RoPE to non-sliding layer only
        position_embeddings = (
            position_embeddings_local if decoder_layer.self_attn.is_sliding else position_embeddings_global
        )

        hidden_states, self_attn_weights, o_proj_input = self._attn(
            attn_layer=decoder_layer.self_attn,
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        # calculate RMS(o_proj_input @ W_o_proj @ Q)
        hidden_states = self._rms_rescale(
            self.o_proj_branchs[decoder_layer_idx],
            o_proj_input,
            hidden_states,
            decoder_layer.post_attention_layernorm.eps,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = decoder_layer.pre_feedforward_layernorm(hidden_states)

        # get the inputs of down_proj
        hidden_states, down_proj_input = self._mlp(decoder_layer.mlp, hidden_states)

        # calculate RMS(down_proj_input @ W_down_proj @ Q)
        hidden_states = self._rms_rescale(
            self.down_proj_branchs[decoder_layer_idx],
            down_proj_input,
            hidden_states,
            decoder_layer.post_feedforward_layernorm.eps,
        )
        hidden_states = residual + hidden_states

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequenceClassifierOutputWithPast:
        """
        Run the customized decoder stack, pool the classification logits and print the forward cost.

        The logits are pooled at the right-most non-padding token of each row when
        `input_ids` and `config.pad_token_id` are available, otherwise at the last position.

        Raises:
            ValueError: If not exactly one of `input_ids` / `inputs_embeds` is given, or
                if the batch size is not 1 while no padding token is configured.
        """
        # Measure the time spent in the forward function
        start_time = time.perf_counter()

        model_config = self.model.config
        output_attentions = output_attentions if output_attentions is not None else model_config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else model_config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else model_config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.model.gradient_checkpointing and self.model.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        if use_cache and past_key_values is None and not self.model.training:
            past_key_values = DynamicCache(config=model_config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": model_config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            sliding_mask_kwargs = mask_kwargs.copy()

            if model_config.use_bidirectional_attention:
                mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
                sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(model_config.sliding_window)

            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
            }

        # embed positions
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings_global = self.model.rotary_emb(hidden_states, position_ids)
        position_embeddings_local = self.model.rotary_emb_local(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for i, decoder_layer in enumerate(self.model.layers[: model_config.num_hidden_layers]):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # Both layer variants take the same keyword arguments; only the mask is per-layer.
            layer_kwargs = {
                "position_embeddings_global": position_embeddings_global,
                "position_embeddings_local": position_embeddings_local,
                "attention_mask": causal_mask_mapping[decoder_layer.attention_type],
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "output_attentions": output_attentions,
                "use_cache": use_cache,
                "cache_position": cache_position,
            }
            if i == 0:
                # The first layer additionally obfuscates its output.
                layer_outputs = self._custom_first_decoder_forward(
                    decoder_layer, hidden_states, **layer_kwargs, **kwargs
                )
            else:
                layer_outputs = self._custom_remaining_decoder_forward(
                    i, hidden_states, **layer_kwargs, **kwargs
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.model.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        transformer_outputs = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

        hidden_states = transformer_outputs.last_hidden_state

        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        output = SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
        # Make sure all GPU work is included in the reported time.
        self._sync_cuda()
        print(f"----------------------------------------------------\nForward cost: {(time.perf_counter() - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")

        return output
        
class CustomGemma3ForCausalLM(Gemma3ForCausalLM):
    r"""
    Modified based on the Gemma3ForCausalLM and Gemma3TextModel class.

    If you cannot run this code in your own environment, you can find the Gemma3TextModel 
    code from your local transformers library and rewrite its forward function as follows.
    """
    def __init__(
        self,
        config,
        v_list0: list[list[torch.Tensor]],
        indices_list0: list[torch.Tensor],
        v_list: list[list[torch.Tensor]],
        indices_list: list[torch.Tensor],
        optimized_stage: int,
        simulate: bool
    ):
        """
        Build the base Gemma3 causal LM and attach the obfuscation state.

        Args:
            config: Model configuration forwarded to the parent constructor.
            v_list0: Handed to `_obfus_col_inv` together with `indices_list0` in the
                first decoder layer when simulating.
            indices_list0: See `v_list0`.
            v_list: Handed to `_obfus_col` / `_obfus_col_inv` together with the
                (inverse) `indices_list` when simulating.
            indices_list: See `v_list`; its per-entry argsort is cached in
                `indices_list_inv` when `simulate` is true.
            optimized_stage: Forwarded unchanged to `sgx_ours.perform_obfuscation`.
            simulate: If true, obfuscation/OTP steps run in torch; otherwise they are
                delegated to `sgx_ours`.
        """
        super().__init__(config)

        self.v_list0 = v_list0
        self.indices_list0 = indices_list0
        self.v_list = v_list
        self.indices_list = indices_list
        # OTP tensors are filled in by `prepare_otp_params` (simulate mode only).
        self.otp = None
        self.otp_logits = None
        self.optimized_stage = optimized_stage
        self.simulate = simulate

        def _same_shape_linear(proj):
            # Fresh linear layer with the same in/out features and bias presence as `proj`.
            return nn.Linear(proj.in_features, proj.out_features, bias=proj.bias is not None)

        layer_ids = range(self.config.num_hidden_layers)

        # Side branches used to compute the RMS factor; layer 0 has no o_proj branch.
        self.o_proj_branchs = nn.ModuleList([
            _same_shape_linear(self.model.layers[idx].self_attn.o_proj) if idx != 0 else None
            for idx in layer_ids
        ])

        self.down_proj_branchs = nn.ModuleList([
            _same_shape_linear(self.model.layers[idx].mlp.down_proj)
            for idx in layer_ids
        ])

        # The inverse index lists are only consumed by the simulated recovery path.
        if self.simulate:
            self.indices_list_inv = [torch.argsort(perm) for perm in self.indices_list]

    def _mlp(self, mlp, x):
        """
        Run the gated MLP and expose the tensor that feeds `down_proj`.

        Returns:
            A pair `(mlp.down_proj(inner), inner)` where
            `inner = act_fn(gate_proj(x)) * up_proj(x)`.
        """
        gated = mlp.act_fn(mlp.gate_proj(x))
        inner = gated * mlp.up_proj(x)
        return mlp.down_proj(inner), inner
    
    def _attn(
        self,
        attn_layer,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """
        Run `attn_layer` step by step so the input of `o_proj` can be returned.

        Args:
            attn_layer: Attention module providing the q/k/v/o projections, the q/k
                norms, `head_dim`, `scaling`, `sliding_window`, `layer_idx` and `config`.
            hidden_states: Input whose last dimension is projected by q/k/v.
            position_embeddings: `(cos, sin)` pair used for the rotary embedding.
            attention_mask: Mask forwarded to the attention kernel.
            past_key_values: Optional cache updated with the new key/value states.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded untouched to the attention kernel.

        Returns:
            `(attn_output, attn_weights, o_proj_input)`: the projected attention
            output, the weights reported by the kernel, and the tensor fed to `o_proj`.
        """
        lead_dims = hidden_states.shape[:-1]
        per_head_shape = (*lead_dims, -1, attn_layer.head_dim)

        def _to_heads(projected: torch.Tensor) -> torch.Tensor:
            # (..., seq, heads * head_dim) -> (..., heads, seq, head_dim)
            return projected.view(per_head_shape).transpose(1, 2)

        q = _to_heads(attn_layer.q_proj(hidden_states))
        k = _to_heads(attn_layer.k_proj(hidden_states))
        v = _to_heads(attn_layer.v_proj(hidden_states))

        q = attn_layer.q_norm(q)
        k = attn_layer.k_norm(k)

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            k, v = past_key_values.update(
                k, v, attn_layer.layer_idx, {"sin": sin, "cos": cos, "cache_position": cache_position}
            )

        # Pick the kernel that matches the configured attention implementation.
        impl_name = attn_layer.config._attn_implementation
        attention_interface: Callable = (
            eager_attention_forward if impl_name == "eager" else ALL_ATTENTION_FUNCTIONS[impl_name]
        )

        context, attn_weights = attention_interface(
            attn_layer,
            q,
            k,
            v,
            attention_mask,
            dropout=attn_layer.attention_dropout if attn_layer.training else 0.0,
            scaling=attn_layer.scaling,
            sliding_window=attn_layer.sliding_window,
            **kwargs,
        )

        # Merge the heads back; this merged tensor is exactly what o_proj consumes.
        o_proj_input = context.reshape(*lead_dims, -1).contiguous()
        return attn_layer.o_proj(o_proj_input), attn_weights, o_proj_input
    
    def _custom_first_decoder_forward(
        self,
        decoder_layer,
        hidden_states: torch.Tensor,
        position_embeddings_global: torch.Tensor,
        position_embeddings_local: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        r"""
        Forward pass of the first decoder layer, ending with the obfuscation step.

        Output = (hidden_states + output_attn + output_mlp @ Pi0.T) @ Pi

        The attention half runs through the layer's own modules. In the MLP half the
        post-feedforward norm is replaced by a scale computed from
        ``down_proj_branchs[0]``. The final residual add and obfuscation are done in
        torch when ``self.simulate`` is true and by ``sgx_ours.perform_obfuscation``
        otherwise (that call is timed and the timing is printed).

        Returns:
            ``(hidden_states,)`` or ``(hidden_states, self_attn_weights)`` when
            ``output_attentions`` is truthy.
        """
        normed_input = decoder_layer.input_layernorm(hidden_states)

        # apply global RoPE to non-sliding layer only
        rope = position_embeddings_local if decoder_layer.self_attn.is_sliding else position_embeddings_global

        attn_out, self_attn_weights = decoder_layer.self_attn(
            hidden_states=normed_input,
            position_embeddings=rope,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        # First residual connection (input + normalised attention output).
        after_attn = hidden_states + decoder_layer.post_attention_layernorm(attn_out)

        # get the inputs of down_proj
        mlp_out, down_proj_input = self._mlp(
            decoder_layer.mlp, decoder_layer.pre_feedforward_layernorm(after_attn)
        )

        # calculate RMS(down_proj_input @ W_down_proj @ Q)
        branch_out = self.down_proj_branchs[0](down_proj_input)
        inv_rms = torch.rsqrt(
            branch_out.pow(2).mean(-1, keepdim=True) + decoder_layer.post_feedforward_layernorm.eps
        )
        mlp_out = mlp_out * inv_rms

        if self.simulate:
            # residual + output_mlp @ Pi0.T
            merged = after_attn + _obfus_col_inv(mlp_out, self.indices_list0, self.v_list0)
            layer_out = _obfus_col(merged, self.indices_list, self.v_list)
        else:
            # Fix inaccurate data transfer timing measurement
            torch.cuda.synchronize()

            # Use SGX for obfuscation
            start_time = time.perf_counter()
            layer_out = sgx_ours.perform_obfuscation(mlp_out, after_attn, self.optimized_stage)
            end_time = time.perf_counter()
            print(f"----------------------------------------------------\nSgx cost(input_obf+data_transfer): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")

        if output_attentions:
            return (layer_out, self_attn_weights)
        return (layer_out,)
    
    def _custom_remaining_decoder_forward(
        self,
        decoder_layer_idx,
        hidden_states: torch.Tensor,
        position_embeddings_global: torch.Tensor,
        position_embeddings_local: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Forward pass for decoder layer `decoder_layer_idx` (used for layers after the first).

        Both post-sublayer norms are replaced by a scale computed from the matching
        side branch (`o_proj_branchs` / `down_proj_branchs`) using the layer's own eps.

        Returns:
            `(hidden_states,)` or `(hidden_states, self_attn_weights)` when
            `output_attentions` is truthy.
        """
        layer = self.model.layers[decoder_layer_idx]

        def _inv_rms(branch_out: torch.Tensor, eps) -> torch.Tensor:
            # 1 / sqrt(mean(x^2) + eps) over the last dimension.
            return torch.rsqrt(branch_out.pow(2).mean(-1, keepdim=True) + eps)

        # ---- attention sub-block ----
        normed_input = layer.input_layernorm(hidden_states)

        # apply global RoPE to non-sliding layer only
        rope = position_embeddings_local if layer.self_attn.is_sliding else position_embeddings_global

        attn_out, self_attn_weights, o_proj_input = self._attn(
            attn_layer=layer.self_attn,
            hidden_states=normed_input,
            position_embeddings=rope,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        # calculate RMS(o_proj_input @ W_o_proj @ Q)
        attn_scale = _inv_rms(
            self.o_proj_branchs[decoder_layer_idx](o_proj_input), layer.post_attention_layernorm.eps
        )
        after_attn = hidden_states + attn_out * attn_scale

        # ---- feed-forward sub-block ----
        # get the inputs of down_proj
        mlp_out, down_proj_input = self._mlp(layer.mlp, layer.pre_feedforward_layernorm(after_attn))

        # calculate RMS(down_proj_input @ W_down_proj @ Q)
        mlp_scale = _inv_rms(
            self.down_proj_branchs[decoder_layer_idx](down_proj_input), layer.post_feedforward_layernorm.eps
        )
        layer_out = after_attn + mlp_out * mlp_scale

        if output_attentions:
            return (layer_out, self_attn_weights)
        return (layer_out,)
    
    def _recover_and_otp(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Undo the obfuscation on the final hidden states and add the one-time pad.

        The final-norm scale (`norm.weight + 1`) is divided out before the
        recovery/OTP step and multiplied back afterwards. The step runs in torch
        when simulating, otherwise in `sgx_ours.perform_otp` (timed and printed).

        hidden_states: (batch_size, logits_to_keep, hidden_size)
        """
        norm_scale = self.model.norm.weight.data + 1.
        unscaled = hidden_states / norm_scale

        if self.simulate:
            recovered = _obfus_col_inv(unscaled, self.indices_list_inv, self.v_list)
            padded = recovered + self.otp
        else:
            # Fix inaccurate data transfer timing measurement
            torch.cuda.synchronize()

            start_time = time.perf_counter()
            padded = sgx_ours.perform_otp(unscaled)
            end_time = time.perf_counter()
            print(f"----------------------------------------------------\nSgx cost(otp+data_transfer): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")

        return padded * norm_scale
    
    def _recover_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Recover the logits from the OTP-masked `lm_head` output.

        When simulating, the stored OTP logits are subtracted and the result is
        turned into a one-hot tensor along the last dimension (1 at the maximum,
        0 elsewhere). Otherwise the work is delegated to
        `sgx_ours.perform_logits_recover`, which is timed and the timing printed.
        """
        if not self.simulate:
            # Fix inaccurate data transfer timing measurement
            torch.cuda.synchronize()

            start_time = time.perf_counter()
            logits = sgx_ours.perform_logits_recover(hidden_states)
            end_time = time.perf_counter()
            print(f"----------------------------------------------------\nSgx cost(logits_recover+data_transfer): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")
            return logits

        unmasked = hidden_states - self.otp_logits
        # replace the max value with 1, others with 0
        top_index = unmasked.max(dim=-1, keepdim=True).indices
        return torch.zeros_like(unmasked).scatter_(-1, top_index, 1.0)
    
    def prepare_otp_params(self, batch_size: int, logits_to_keep: int=1):
        """
        Prepare the OTP params (kept on the module when simulating, sent to SGX otherwise).

        A random pad of shape `(batch_size, logits_to_keep, hidden_size)` is drawn,
        its image under `lm_head` is computed, and the pad itself is divided by the
        final-norm scale (`norm.weight + 1`).

        Args:
            batch_size (int): The batch size.
            logits_to_keep (int, optional): The number of logits to keep. Defaults to 1.

        Raises:
            RuntimeError: If `sgx_ours.prepare_otp_params` reports failure
                (non-simulated mode only).
        """
        # Draw the pad in the dtype of `lm_head` so the projection below also works
        # when the model is loaded in half precision; a default-dtype pad would make
        # `self.lm_head(otp)` fail with a dtype mismatch in that case.
        otp = torch.randn(
            (batch_size, logits_to_keep, self.config.hidden_size),
            device=self.device,
            dtype=self.lm_head.weight.dtype,
        )
        otp_logits = self.lm_head(otp)
        otp = otp / (self.model.norm.weight.data + 1.)

        if self.simulate:
            self.otp = otp
            self.otp_logits = otp_logits
        else:
            if not sgx_ours.prepare_otp_params(otp, otp_logits):
                raise RuntimeError("Failed to initialize SGX OTP params")
        
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Run the obfuscated Gemma3 forward pass and return the recovered logits.

        NOTE: In practice, you should implement an OTP manager inside SGX to generate and
        manage OTP keys in advance, instead of generating them before each inference.

        When measuring performance, we omit this overhead because it can be fully implemented within SGX.

        Args:
            input_ids / inputs_embeds: Exactly one of the two must be given.
            logits_to_keep: An int keeps the last `logits_to_keep` positions (0 keeps
                all); a tensor is used directly as the index along the sequence axis.
            labels: If given, `self.loss_function` is evaluated on the returned logits.
            Remaining arguments are forwarded to mask creation, the cache and the
            decoder layers.

        Returns:
            `CausalLMOutputWithPast` with the recovered logits, the cache and the
            optional hidden states / attentions.

        Raises:
            ValueError: If not exactly one of `input_ids` / `inputs_embeds` is given.
        """
        # Validate first: the OTP shapes below are derived from whichever input exists,
        # so this must not dereference `input_ids` when only `inputs_embeds` is passed.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        shape_source = input_ids if input_ids is not None else inputs_embeds
        self.prepare_otp_params(shape_source.shape[0], shape_source[:, slice_indices].shape[1])

        # Measure the time spent in the forward function
        start_time = time.perf_counter()

        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.model.config.use_cache

        if self.model.gradient_checkpointing and self.model.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        if use_cache and past_key_values is None and not self.model.training:
            past_key_values = DynamicCache(config=self.model.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.model.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            sliding_mask_kwargs = mask_kwargs.copy()

            if self.model.config.use_bidirectional_attention:
                mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
                sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.model.config.sliding_window)

            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
            }

        # embed positions
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings_global = self.model.rotary_emb(hidden_states, position_ids)
        position_embeddings_local = self.model.rotary_emb_local(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for i, decoder_layer in enumerate(self.model.layers[: self.model.config.num_hidden_layers]):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # Both layer variants take the same keyword arguments; only the first
            # positional argument differs (layer module vs. layer index).
            layer_kwargs = dict(
                position_embeddings_global=position_embeddings_global,
                position_embeddings_local=position_embeddings_local,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

            if i == 0:
                # The first layer ends with the obfuscation step.
                layer_outputs = self._custom_first_decoder_forward(decoder_layer, hidden_states, **layer_kwargs)
            else:
                layer_outputs = self._custom_remaining_decoder_forward(i, hidden_states, **layer_kwargs)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.model.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        transformer_outputs = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

        hidden_states = transformer_outputs.last_hidden_state

        # Only compute necessary logits (`slice_indices` was derived from `logits_to_keep` above)
        hidden_states = hidden_states[:, slice_indices, :]
        hidden_states = self._recover_and_otp(hidden_states)
        hidden_states = self.lm_head(hidden_states)
        logits = self._recover_logits(hidden_states)

        # We only return one-hot encoding, the `tanh` does not affect the result
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        output = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
        # Wait for pending GPU work so the timing is meaningful; skip on hosts
        # without CUDA, where `torch.cuda.synchronize()` would raise.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print(f"----------------------------------------------------\nForward cost: {(time.perf_counter() - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")

        return output
    
class CustomGPT2ForSequenceClassification(GPT2ForSequenceClassification):
    r"""
    Modified based on the GPT2ForSequenceClassification and GPT2Model class.

    If you cannot run this code in your own environment, you can find the GPT2Model 
    code from your local transformers library and rewrite its forward function as follows.
    """
    def __init__(
        self,
        config,
        v_list0: list[list[torch.Tensor]],
        indices_list0: list[torch.Tensor],
        v_list: list[list[torch.Tensor]],
        indices_list: list[torch.Tensor],
        optimized_stage: int,
        simulate: bool,
    ):
        """
        Build the base GPT2 classifier and store the obfuscation parameters.

        Args:
            config: Model configuration forwarded to the parent constructor.
            v_list0, indices_list0: Passed to `_obfus_col_inv` in the first decoder
                layer when simulating.
            v_list, indices_list: Passed to `_obfus_col` in the first decoder layer
                when simulating.
            optimized_stage: Forwarded unchanged to `sgx_ours.perform_obfuscation`.
            simulate: If true, obfuscation runs in torch; otherwise via `sgx_ours`.
        """
        super().__init__(config)

        # Parameters of the two obfuscation transforms.
        self.v_list0, self.indices_list0 = v_list0, indices_list0
        self.v_list, self.indices_list = v_list, indices_list

        # Execution-mode switches.
        self.optimized_stage = optimized_stage
        self.simulate = simulate
    
    def _custom_first_decoder_forward(
        self,
        decoder_layer,
        hidden_states: Optional[tuple[torch.FloatTensor]],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ):
        r"""
        Forward pass of the first GPT2 block, ending with the obfuscation step.

        Output = (hidden_states + output_attn + output_mlp @ Pi0.T) @ Pi

        Self-attention (and cross-attention, when ``encoder_hidden_states`` is given)
        run through the block's own modules. The last residual add and the
        obfuscation are done in torch when ``self.simulate`` is true and by
        ``sgx_ours.perform_obfuscation`` otherwise (timed, timing printed).

        Returns:
            ``(hidden_states,)``, extended with the self-attention weights and, if
            cross-attention ran, the cross-attention weights when
            ``output_attentions`` is truthy.

        Raises:
            ValueError: If ``encoder_hidden_states`` is given but the block has no
                ``crossattention`` attribute.
        """
        # ---- self-attention ----
        attn_output, self_attn_weights = decoder_layer.attn(
            decoder_layer.ln_1(hidden_states),
            past_key_values=past_key_values,
            cache_position=cache_position,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs,
        )
        # residual connection
        stream = attn_output + hidden_states

        has_encoder_input = encoder_hidden_states is not None
        if has_encoder_input:
            # add one self-attention block for cross-attention
            if not hasattr(decoder_layer, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {decoder_layer} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            cross_attn_output, cross_attn_weights = decoder_layer.crossattention(
                decoder_layer.ln_cross_attn(stream),
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            # residual connection
            stream = stream + cross_attn_output

        # ---- feed-forward ----
        mlp_output = decoder_layer.mlp(decoder_layer.ln_2(stream))

        if self.simulate:
            # residual + output_mlp @ Pi0.T
            merged = stream + _obfus_col_inv(mlp_output, self.indices_list0, self.v_list0)
            block_out = _obfus_col(merged, self.indices_list, self.v_list)
        else:
            # Fix inaccurate data transfer timing measurement
            torch.cuda.synchronize()

            # Use SGX for obfuscation
            start_time = time.perf_counter()
            block_out = sgx_ours.perform_obfuscation(mlp_output, stream, self.optimized_stage)
            end_time = time.perf_counter()
            print(f"----------------------------------------------------\nSgx cost(input_obf+data_transfer): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")

        if not output_attentions:
            return (block_out,)
        if not has_encoder_input:
            return (block_out, self_attn_weights)
        return (block_out, self_attn_weights, cross_attn_weights)
        
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[tuple, SequenceClassifierOutputWithPast]:
        # Measure the time spent in the forward function
        start_time = time.perf_counter()
        
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        output_attentions = output_attentions if output_attentions is not None else self.transformer.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.transformer.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.transformer.config.use_cache
        return_dict = return_dict if return_dict is not None else self.transformer.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.transformer.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        if self.transformer.gradient_checkpointing and self.transformer.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
        if use_cache:
            if past_key_values is None:
                past_key_values = DynamicCache(config=self.transformer.config)
            elif isinstance(past_key_values, tuple):
                logger.warning_once(
                    "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
                    "You should pass an instance of `Cache` instead, e.g. "
                    "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
                )
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)

            if self.transformer.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache(config=self.transformer.config))

        if inputs_embeds is None:
            inputs_embeds = self.transformer.wte(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        position_embeds = self.transformer.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)

        # Attention mask.
        # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
        if attention_mask is not None and attention_mask.ndim < 4:
            attention_mask = attention_mask.view(batch_size, -1)

        causal_mask = create_causal_mask(
            config=self.transformer.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        _use_sdpa = self.transformer._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
        if self.transformer.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            if _use_sdpa:
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
                )
            elif self.transformer._attn_implementation != "flash_attention_2":
                encoder_attention_mask = self.transformer.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.transformer.get_head_mask(head_mask, self.transformer.config.n_layer)

        if token_type_ids is not None:
            token_type_embeds = self.transformer.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.transformer.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.transformer.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, block in enumerate(self.transformer.h):
            # Model parallel
            if self.transformer.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if i == 0:
                outputs = self._custom_first_decoder_forward(
                    block,
                    hidden_states,
                    past_key_values if not (self.transformer.gradient_checkpointing and self.transformer.training) else None,
                    cache_position,
                    causal_mask,
                    head_mask[i],
                    encoder_hidden_states,  # as a positional argument for gradient checkpointing
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs,
                )
            else:
                outputs = block(
                    hidden_states,
                    past_key_values if not (self.transformer.gradient_checkpointing and self.transformer.training) else None,
                    cache_position,
                    causal_mask,
                    head_mask[i],
                    encoder_hidden_states,  # as a positional argument for gradient checkpointing
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs,
                )

            hidden_states = outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)
                if self.transformer.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.transformer.model_parallel:
                for k, v in self.transformer.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.transformer.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.transformer.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        past_key_values = past_key_values if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        transformer_outputs = BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        outputs = SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
        torch.cuda.synchronize()
        print(f"----------------------------------------------------\nForward cost: {(time.perf_counter() - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")
        
        return outputs
        
class CustomGPT2LMHeadModel(GPT2LMHeadModel):
    r"""
    Modified based on the GPT2LMHeadModel and GPT2Model class.

    If you cannot run this code in your own environment, you can find the GPT2Model 
    code from your local transformers library and rewrite its forward function as follows.
    """
    def __init__(
        self,
        config,
        v_list0: list[list[torch.Tensor]],
        indices_list0: list[torch.Tensor],
        v_list: list[list[torch.Tensor]],
        indices_list: list[torch.Tensor],
        optimized_stage: int,
        simulate: bool,
    ):
        """
        Build the GPT-2 LM-head model and attach the obfuscation parameters.

        Args:
            config: Model configuration forwarded to ``GPT2LMHeadModel``.
            v_list0: Vectors handed to ``_obfus_col_inv`` (together with
                ``indices_list0``) in the first decoder layer.
            indices_list0: Index tensors paired with ``v_list0``.
            v_list: Vectors handed to ``_obfus_col`` in the first decoder layer
                and to ``_obfus_col_inv`` when the output is recovered.
            indices_list: Index tensors paired with ``v_list``.
            optimized_stage: Value forwarded to ``sgx_ours.perform_obfuscation``.
            simulate: If True, obfuscation/OTP are computed in-process; otherwise
                they are delegated to ``sgx_ours``.
        """
        super().__init__(config)

        self.v_list0, self.indices_list0 = v_list0, indices_list0
        self.v_list, self.indices_list = v_list, indices_list
        # OTP tensors are filled in later by `prepare_otp_params` (simulation only).
        self.otp = None
        self.otp_logits = None
        self.optimized_stage = optimized_stage
        self.simulate = simulate

        # The inverse permutations are only needed by the in-process recovery
        # path, so the attribute is deliberately left undefined otherwise.
        if simulate:
            self.indices_list_inv = [torch.argsort(perm) for perm in indices_list]
    
    def _custom_first_decoder_forward(
        self,
        decoder_layer,
        hidden_states: Optional[tuple[torch.FloatTensor]],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ):
        r"""
        Run the first decoder block and obfuscate what it emits.

        The attention (and optional cross-attention) sub-layers are evaluated
        as usual; only the final residual addition is replaced:

        Output = (hidden_states + output_attn + output_mlp @ Pi0.T) @ Pi

        In simulation mode this is computed with ``_obfus_col_inv`` /
        ``_obfus_col``; otherwise the MLP output and the residual stream are
        handed to ``sgx_ours.perform_obfuscation`` and the call is timed.

        Returns:
            A tuple ``(hidden_states,)``, extended with the self-attention
            weights when ``output_attentions`` is truthy, and additionally with
            the cross-attention weights when ``encoder_hidden_states`` is given.

        Raises:
            ValueError: If ``encoder_hidden_states`` is passed but the layer has
                no ``crossattention`` module.
        """
        # --- self-attention sub-layer (pre-norm) + residual -------------------
        attn_output, self_attn_weights = decoder_layer.attn(
            decoder_layer.ln_1(hidden_states),
            past_key_values=past_key_values,
            cache_position=cache_position,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs,
        )
        stream = attn_output + hidden_states

        # --- optional cross-attention sub-layer + residual --------------------
        has_encoder = encoder_hidden_states is not None
        if has_encoder:
            if not hasattr(decoder_layer, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {decoder_layer} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            cross_attn_output, cross_attn_weights = decoder_layer.crossattention(
                decoder_layer.ln_cross_attn(stream),
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            stream = stream + cross_attn_output

        # --- feed-forward sub-layer; its residual add is fused with obfuscation
        mlp_output = decoder_layer.mlp(decoder_layer.ln_2(stream))

        if self.simulate:
            # stream + output_mlp @ Pi0.T, then the whole sum @ Pi
            obfuscated = stream + _obfus_col_inv(mlp_output, self.indices_list0, self.v_list0)
            obfuscated = _obfus_col(obfuscated, self.indices_list, self.v_list)
        else:
            # Drain pending GPU work first so the timer below only covers the
            # SGX call (obfuscation + data transfer).
            torch.cuda.synchronize()

            start_time = time.perf_counter()
            obfuscated = sgx_ours.perform_obfuscation(mlp_output, stream, self.optimized_stage)
            end_time = time.perf_counter()
            print(f"----------------------------------------------------\nSgx cost(input_obf+data_transfer): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")

        if not output_attentions:
            return (obfuscated,)
        if has_encoder:
            return (obfuscated, self_attn_weights, cross_attn_weights)
        return (obfuscated, self_attn_weights)
        
    def _recover_and_otp(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Undo the obfuscation on the final hidden states and apply the OTP.

        The final LayerNorm weight is divided out before the recovery step and
        multiplied back in afterwards.
        NOTE(review): only ``ln_f.weight`` is factored out; any LayerNorm bias
        is left in place - confirm this is intended.

        Args:
            hidden_states: (batch_size, logits_to_keep, hidden_size)

        Returns:
            Tensor produced by the recovery + OTP step, rescaled by the
            LayerNorm weight.
        """
        ln_scale = self.transformer.ln_f.weight.data
        unscaled = hidden_states / ln_scale

        if self.simulate:
            # In-process path: invert the column obfuscation, then add the pad.
            recovered = _obfus_col_inv(unscaled, self.indices_list_inv, self.v_list)
            padded = recovered + self.otp
        else:
            # Drain pending GPU work so the timer only covers the SGX call.
            torch.cuda.synchronize()

            start_time = time.perf_counter()
            padded = sgx_ours.perform_otp(unscaled)
            end_time = time.perf_counter()
            print(f"----------------------------------------------------\nSgx cost(otp+data_transfer): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")

        return padded * ln_scale
    
    def _recover_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Recover the logits from the OTP-masked LM-head output.

        In simulation mode the stored ``otp_logits`` are subtracted and the
        result is collapsed to a one-hot tensor along the last dimension
        (1.0 at the maximum, 0.0 elsewhere). Otherwise the work is delegated to
        ``sgx_ours.perform_logits_recover`` and the call is timed.
        """
        if not self.simulate:
            # Drain pending GPU work so the timer only covers the SGX call.
            torch.cuda.synchronize()

            start_time = time.perf_counter()
            logits = sgx_ours.perform_logits_recover(hidden_states)
            end_time = time.perf_counter()
            print(f"----------------------------------------------------\nSgx cost(logits_recover+data_transfer): {(end_time - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")
            return logits

        unmasked = hidden_states - self.otp_logits
        # Keep only the winning position: 1 at the max, 0 everywhere else.
        top_index = torch.max(unmasked, dim=-1, keepdim=True).indices
        one_hot = torch.zeros_like(unmasked)
        one_hot.scatter_(-1, top_index, 1.0)
        return one_hot
    
    def prepare_otp_params(self, batch_size: int, logits_to_keep: int=1):
        """
        Generate a fresh one-time pad and register it (in-process or in SGX).

        A random pad of shape ``(batch_size, logits_to_keep, hidden_size)`` is
        drawn, its image under ``lm_head`` is computed, and the pad is then
        divided by the final LayerNorm weight. In simulation mode both tensors
        are stored on the model (``self.otp`` / ``self.otp_logits``); otherwise
        they are passed to ``sgx_ours.prepare_otp_params``.

        Args:
            batch_size (int): The batch size.
            logits_to_keep (int, optional): The number of logits to keep. Defaults to 1.

        Raises:
            RuntimeError: If ``sgx_ours.prepare_otp_params`` reports failure.
        """
        # Draw the pad in the LM head's dtype: `torch.randn` otherwise yields the
        # default dtype and `self.lm_head(otp)` raises a dtype mismatch when the
        # model is loaded in fp16/bf16. Non-floating weights (e.g. quantized)
        # fall back to the default dtype, i.e. the previous behaviour.
        head_dtype = self.lm_head.weight.dtype
        otp_dtype = head_dtype if head_dtype.is_floating_point else torch.get_default_dtype()

        # NOTE(review): `torch.randn` is not a cryptographically secure source;
        # acceptable for simulation/benchmarking, confirm for production use.
        otp = torch.randn(
            (batch_size, logits_to_keep, self.config.hidden_size),
            device=self.device,
            dtype=otp_dtype,
        )
        # The logits-space pad must be computed from the pad *before* it is
        # rescaled by the LayerNorm weight.
        otp_logits = self.lm_head(otp)
        otp = otp / self.transformer.ln_f.weight.data

        if self.simulate:
            self.otp = otp
            self.otp_logits = otp_logits
        elif not sgx_ours.prepare_otp_params(otp, otp_logits):
            raise RuntimeError("Failed to initialize SGX OTP params")
    
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
        """
        Run the GPT-2 stack with an obfuscated first layer, then recover the logits.

        The transformer forward pass is inlined here so that block 0 can be
        routed through `_custom_first_decoder_forward`. After the final
        LayerNorm, the kept positions go through `_recover_and_otp`, `lm_head`
        and `_recover_logits`.

        NOTE: In practice, you should implement an OTP manager inside SGX to generate and
        manage OTP keys in advance, instead of generating them before each inference.

        When measuring performance, we omit this overhead because it can be fully implemented within SGX.

        Args:
            input_ids / inputs_embeds: Exactly one of the two must be given.
            logits_to_keep: If an int, the last `logits_to_keep` positions are
                kept (0 keeps all); if a tensor, it is used to index the
                sequence dimension.
            return_dict: Falls back to `self.config.use_return_dict` when None.
            **kwargs: Forwarded to every decoder block and to `self.loss_function`.
            (Remaining arguments are used as in the code below.)

        Returns:
            A `CausalLMOutputWithCrossAttentions`, or, when `return_dict` is
            False, the tuple `(loss?, logits, *transformer_outputs[1:])`.
            The wall-clock forward time is printed in both cases.

        Raises:
            ValueError: If both or neither of `input_ids` / `inputs_embeds` are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep

        # Size the OTP from whichever input was supplied; both have the batch on
        # dim 0 and the sequence on dim 1, so `inputs_embeds`-only calls no
        # longer fail on a `None` `input_ids`.
        otp_source = input_ids if input_ids is not None else inputs_embeds
        self.prepare_otp_params(otp_source.shape[0], otp_source[:, slice_indices].shape[1])

        # Measure the time spent in the forward function (OTP setup excluded on purpose)
        start_time = time.perf_counter()

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        output_attentions = output_attentions if output_attentions is not None else self.transformer.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.transformer.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.transformer.config.use_cache

        if input_ids is not None:
            self.transformer.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        else:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        checkpointing = self.transformer.gradient_checkpointing and self.transformer.training
        if checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
        if use_cache:
            if past_key_values is None:
                past_key_values = DynamicCache(config=self.transformer.config)
            elif isinstance(past_key_values, tuple):
                logger.warning_once(
                    "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
                    "You should pass an instance of `Cache` instead, e.g. "
                    "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
                )
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)

            if self.transformer.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache(config=self.transformer.config))

        if inputs_embeds is None:
            inputs_embeds = self.transformer.wte(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        position_embeds = self.transformer.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)

        # Attention mask: flatten anything below 4D to (batch, -1) before the
        # causal mask is built.
        if attention_mask is not None and attention_mask.ndim < 4:
            attention_mask = attention_mask.view(batch_size, -1)

        causal_mask = create_causal_mask(
            config=self.transformer.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        _use_sdpa = self.transformer._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
        if self.transformer.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            if _use_sdpa:
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
                )
            elif self.transformer._attn_implementation != "flash_attention_2":
                encoder_attention_mask = self.transformer.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.transformer.get_head_mask(head_mask, self.transformer.config.n_layer)

        if token_type_ids is not None:
            token_type_embeds = self.transformer.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.transformer.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.transformer.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None

        # The KV cache is withheld from the blocks while gradient checkpointing
        # is active during training.
        layer_cache = None if checkpointing else past_key_values

        for i, block in enumerate(self.transformer.h):
            # Model parallel
            if self.transformer.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_args = (
                hidden_states,
                layer_cache,
                cache_position,
                causal_mask,
                head_mask[i],
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
            )
            # Only block 0 is replaced: its output is obfuscated before it
            # reaches the remaining, unmodified blocks.
            layer_fn = block if i != 0 else functools.partial(self._custom_first_decoder_forward, block)
            outputs = layer_fn(
                *layer_args,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
                **kwargs,
            )

            hidden_states = outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)
                if self.transformer.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.transformer.model_parallel:
                for k, v in self.transformer.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.transformer.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.transformer.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        past_key_values = past_key_values if use_cache else None

        # Always wrap the backbone results, even for `return_dict=False`: an
        # early tuple return here would hand back the still-obfuscated hidden
        # states and skip OTP, LM head, logits recovery and the loss.
        # Slicing the output object (`[1:]`) yields its non-None fields as a tuple.
        transformer_outputs = BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

        hidden_states = transformer_outputs[0]
        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        hidden_states = hidden_states[:, slice_indices, :]
        hidden_states = self._recover_and_otp(hidden_states)
        hidden_states = self.lm_head(hidden_states)
        logits = self._recover_logits(hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if return_dict:
            outputs = CausalLMOutputWithCrossAttentions(
                loss=loss,
                logits=logits,
                past_key_values=transformer_outputs.past_key_values,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
                cross_attentions=transformer_outputs.cross_attentions,
            )
        else:
            outputs = (logits,) + transformer_outputs[1:]
            if loss is not None:
                outputs = (loss,) + outputs

        # Wait for outstanding GPU kernels so the timing is meaningful; skipped
        # on CPU-only hosts, where `torch.cuda.synchronize()` would raise.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print(f"----------------------------------------------------\nForward cost: {(time.perf_counter() - start_time) * 1000:.6f} milliseconds\n----------------------------------------------------")

        return outputs
        
    
def convert_to_custom_model(original_model, v_list0, indices_list0, v_list, indices_list, optimized_stage=0, simulate=True):
    """
    Build the obfuscation-aware ``Custom*`` counterpart of a Hugging Face model.

    Supported model types: Qwen3ForSequenceClassification, Qwen3ForCausalLM,
    LlamaForSequenceClassification, LlamaForCausalLM,
    Gemma3TextForSequenceClassification, Gemma3ForCausalLM,
    GPT2ForSequenceClassification and GPT2LMHeadModel.

    NOTE: ``v_list0``, ``v_list`` and ``indices_list0`` are modified IN PLACE:
    every vector is divided by its L2 norm and every entry of ``indices_list0``
    is replaced by its ``torch.argsort``. Passing the same ``indices_list0``
    object to this function a second time therefore argsorts it again (for a
    permutation this undoes the inversion), so pass fresh inputs on each call.

    Args:
        original_model: Original model; its config and state dict are reused.
        v_list0: Nested sequence (rows of vectors) for the first obfuscation stage.
        indices_list0: List of index tensors for the first obfuscation stage.
            They are inverted here via ``torch.argsort`` (assumes each entry is
            a permutation -- confirm against the caller).
        v_list: Nested sequence (rows of vectors) for the second obfuscation stage.
        indices_list: List of indices for the second obfuscation stage (used as given).
        optimized_stage: Stage of input obfuscation to apply (-1, 0, 1, 2 or 3).
            Default is 0 (stage 2 is documented as the fastest).
        simulate: Whether to simulate obfuscation or run in SGX. Default is True.
    Returns:
        custom_model: CustomXXX instance carrying the original weights, placed on
            the device of the original model's first parameter.
    Raises:
        ValueError: If ``original_model`` is not one of the supported types.
        RuntimeError: If ``simulate`` is False and SGX rejects the parameters.
    """
    # Ordered (original class, custom class) pairs. The first isinstance match
    # wins, in the same order the checks have always been performed.
    supported_models = (
        (Qwen3ForSequenceClassification, CustomQwen3ForSequenceClassification),
        (Qwen3ForCausalLM, CustomQwen3ForCausalLM),
        (LlamaForSequenceClassification, CustomLlamaForSequenceClassification),
        (LlamaForCausalLM, CustomLlamaForCausalLM),
        (Gemma3TextForSequenceClassification, CustomGemma3TextForSequenceClassification),
        (Gemma3ForCausalLM, CustomGemma3ForCausalLM),
        (GPT2ForSequenceClassification, CustomGPT2ForSequenceClassification),
        (GPT2LMHeadModel, CustomGPT2LMHeadModel),
    )
    custom_cls = next(
        (custom for base, custom in supported_models if isinstance(original_model, base)),
        None,
    )
    # Reject unsupported models before touching the caller's lists.
    if custom_cls is None:
        raise ValueError(f"Unsupported model type: {type(original_model)}")

    config = original_model.config

    # Normalize every obfuscation vector to unit L2 norm, in place. Each row is
    # walked over its own length so rows of differing sizes are fully handled.
    for vector_rows in (v_list0, v_list):
        for row in vector_rows:
            for j, vec in enumerate(row):
                row[j] = vec / torch.norm(vec, p=2)

    # NOTE: indices_list0 should be inversed
    for i, indices in enumerate(indices_list0):
        indices_list0[i] = torch.argsort(indices)

    # In simulation mode the model holds the obfuscation parameters itself;
    # otherwise they are handed to SGX below and the model receives None.
    if simulate:
        obf_args = (v_list0, indices_list0, v_list, indices_list)
    else:
        obf_args = (None, None, None, None)
    custom_model = custom_cls(config, *obf_args, optimized_stage, simulate)

    custom_model.load_state_dict(original_model.state_dict())

    device = next(original_model.parameters()).device
    custom_model = custom_model.to(device)

    # Pass obfuscation parameters to SGX when not in simulation mode.
    # In practice, you should pass this parameter through an encrypted file on disk
    # and decrypt it within SGX to prevent leakage.
    if not simulate:
        # Initialize SGX and pass the parameters
        if not sgx_ours.prepare_input_obf_params(v_list0, indices_list0, v_list, indices_list):
            raise RuntimeError("Failed to initialize SGX input obfuscation params")

    return custom_model