import os
import ctypes
import numpy as np
import torch
import atexit
import time
from transformers import Gemma3PreTrainedModel

import logging

# Custom formatter with colors for different log levels
class ColoredFormatter(logging.Formatter):
    """Formatter that renders records as ``[LEVEL] logger-name: message``.

    The ``[LEVEL]`` tag is wrapped in the ANSI color registered for the
    record's level in ``COLORS``; levels without an entry use the reset
    sequence, i.e. no color. Any ``fmt`` passed to the constructor is not used
    by ``format``: the layout is fixed.
    """

    # ANSI color codes
    COLORS = {
        logging.DEBUG: '\033[94m',    # Blue
        logging.INFO: '\033[92m',     # Green
        logging.WARNING: '\033[93m',  # Yellow
        logging.ERROR: '\033[91m',    # Red
        logging.CRITICAL: '\033[95m', # Purple
        'RESET': '\033[0m'            # Reset color
    }

    def format(self, record):
        """Return ``record`` formatted with its level tag colored.

        The delegate formatter for each level number is built on first use and
        reused afterwards, instead of constructing a new ``logging.Formatter``
        for every record.
        """
        # Stored in the instance __dict__ so the cache also works for
        # instances created without running __init__.
        cache = self.__dict__.setdefault("_level_formatters", {})
        delegate = cache.get(record.levelno)
        if delegate is None:
            # Get the color for this log level
            color = self.COLORS.get(record.levelno, self.COLORS['RESET'])
            delegate = logging.Formatter(
                f'{color}[%(levelname)s]{self.COLORS["RESET"]} %(name)s: %(message)s'
            )
            cache[record.levelno] = delegate
        return delegate.format(record)

# Module logger: records at INFO and above are sent to a dedicated stream
# handler (StreamHandler() without arguments writes to sys.stderr) and
# rendered by the ColoredFormatter defined above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = ColoredFormatter()
console_handler.setFormatter(formatter)
# NOTE(review): the handler is attached at import time and propagation is left
# at its default, so messages may also reach root-logger handlers - confirm
# this is intended.
logger.addHandler(console_handler)

# Paths to the SGX bridge library and the signed enclave image, both resolved
# relative to the directory containing this file.
SGX_LIB_PATH = os.path.join(os.path.dirname(__file__), '../../sgx/groupcover/build/sgx_groupcover.so')
ENCLAVE_PATH = os.path.join(os.path.dirname(__file__), '../../sgx/groupcover/build/enclave.signed.so')

# Global SGX instance; stays None until get_sgx_instance() creates it.
sgx_instance = None

class SGX(object):
    def __init__(self):
        """Set up lookup tables and timers, then load the library and enclave.

        Attributes initialised here:
            sgx_lib_path / enclave_path: Copies of the module-level paths.
            sgx_lib / enclave_id: ``None`` until ``init_sgx`` fills them in.
            norm_enum: Norm-layer attribute name -> integer code.
            restore_enum: Projection-layer name -> integer code.
            sgx_compute_time: Last value read from the library's timer.
            all_time: Wall-clock milliseconds accumulated by the compute
                methods of this class.

        Raises:
            RuntimeError: Propagated from ``init_sgx`` when loading fails.
        """
        # The position of a name inside its tuple is the integer code sent to
        # the library for that name.
        norm_names = (
            "input_layernorm",
            "post_attention_layernorm",
            "q_norm",
            "k_norm",
            "ln_1",
            "ln_2",
            "ln_f",
            "norm",
            "pre_feedforward_layernorm",
            "post_feedforward_layernorm",
        )
        restore_names = (
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "c_fc",
            "c_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        )

        self.sgx_lib_path = SGX_LIB_PATH
        self.enclave_path = ENCLAVE_PATH
        self.sgx_lib = None
        self.enclave_id = None

        self.norm_enum = {name: code for code, name in enumerate(norm_names)}
        self.restore_enum = {name: code for code, name in enumerate(restore_names)}

        self.sgx_compute_time = 0.0
        self.all_time = 0.0

        self.init_sgx()
    
    def init_sgx(self):
        """Load the SGX bridge library, declare its C signatures, start the enclave.

        On success ``self.sgx_lib`` holds the loaded ``ctypes.CDLL`` and
        ``self.enclave_id`` the value returned by ``get_enclave_id``.

        Raises:
            RuntimeError: If the library cannot be loaded, a symbol is missing,
                or ``init_enclave`` returns a non-zero status. The underlying
                exception is chained as ``__cause__``.
        """
        try:
            # loading SGX lib
            logger.info(f"Loading SGX library from path: {self.sgx_lib_path}")
            self.sgx_lib = ctypes.CDLL(self.sgx_lib_path, mode=ctypes.RTLD_GLOBAL)
            lib = self.sgx_lib

            # Shorthands for the C types that recur in the signatures below.
            enclave_id_t = ctypes.c_ulonglong
            float_ptr = ctypes.POINTER(ctypes.c_float)
            dim_t = ctypes.c_size_t
            status_t = ctypes.c_int

            # define function type
            lib.init_enclave.argtypes = [ctypes.c_char_p]
            lib.init_enclave.restype = status_t

            lib.get_enclave_id.restype = enclave_id_t
            lib.get_sgx_exe_time.restype = ctypes.c_double
            # NOTE(review): c_void_p is used as restype here and for
            # destroy_enclave; if the C functions return void, None would be
            # the conventional restype - confirm against the C header.
            lib.reset_sgx_exe_time.restype = ctypes.c_void_p

            # Both parameter uploaders share one signature:
            # (enclave id, name code, pointer to a parameter-array struct).
            for symbol in ("prepare_obf_params", "prepare_norm_params"):
                uploader = getattr(lib, symbol)
                uploader.argtypes = [enclave_id_t, ctypes.c_int, ctypes.c_void_p]
                uploader.restype = status_t

            lib.restore.argtypes = [
                enclave_id_t, float_ptr, ctypes.c_int, ctypes.c_int,
                float_ptr, dim_t, dim_t, dim_t,
                float_ptr,
            ]
            lib.restore.restype = status_t

            # The three activations share one signature:
            # (enclave id, input, three dimensions, output).
            for symbol in ("silu_activation", "gelu_tanh_activation", "new_gelu_activation"):
                activation = getattr(lib, symbol)
                activation.argtypes = [enclave_id_t, float_ptr, dim_t, dim_t, dim_t, float_ptr]
                activation.restype = status_t

            lib.norm.argtypes = [
                enclave_id_t, float_ptr, ctypes.c_int,
                ctypes.c_int, dim_t, dim_t, dim_t,
                float_ptr,
            ]
            lib.norm.restype = status_t

            lib.destroy_enclave.restype = ctypes.c_void_p

            # init SGX enclave
            logger.info(f"Initializing SGX enclave from path: {self.enclave_path}")
            status = lib.init_enclave(self.enclave_path.encode('utf-8'))
            if status != 0:
                raise RuntimeError(f"Failed to initialize SGX enclave: {status}")

            self.enclave_id = lib.get_enclave_id()

        except Exception as e:
            raise RuntimeError(f"Failed to initialize SGX: {str(e)}") from e

    def destroy_enclave(self):
        """
        Cleanup SGX resources.

        Calls the library's ``destroy_enclave`` entry point when a library has
        been loaded; does nothing otherwise.
        """
        lib = self.sgx_lib
        if not lib:
            return
        lib.destroy_enclave()

    def prepare_obf_params(self, obf_param):
        """
        Prepare obfuscation parameters in SGX.

        For every entry of ``obf_param`` the per-layer parameters are packed
        into C structures (``ObfParamArray`` -> ``ObfParam`` -> ``Block``) and
        handed to the library's ``prepare_obf_params`` entry point. Each block
        is inverted with ``torch.linalg.inv`` before it is uploaded.

        Args:
            obf_param: Mapping from a projection name (a key of
                ``self.restore_enum``) to a list with one dict per layer. Each
                dict provides ``"permutation"`` (integer values, one per
                position) and ``"blocks"`` (a sequence of square 2-D tensors).

        Returns:
            True when every entry was accepted; False as soon as the library
            reports a non-zero status (the failure is logged).

        Raises:
            RuntimeError: If the library is not loaded, or if packing or
                uploading raises; the original error is chained.
        """
        if self.sgx_lib is None:
            raise RuntimeError("SGX library not initialized")

        class Block(ctypes.Structure):
            _fields_ = [
                ("size", ctypes.c_size_t),
                ("data", ctypes.POINTER(ctypes.c_float)),
            ]

        class ObfParam(ctypes.Structure):
            _fields_ = [
                ("perm_size", ctypes.c_size_t),
                ("block_count", ctypes.c_size_t),
                ("blocks", ctypes.POINTER(Block)),
                ("permutation", ctypes.POINTER(ctypes.c_int)),
            ]

        class ObfParamArray(ctypes.Structure):
            _fields_ = [
                ("count", ctypes.c_size_t),
                ("params", ctypes.POINTER(ObfParam)),
            ]

        def filled_c_array(c_type, length, values):
            """Return a ctypes array of ``length`` items bulk-copied from ``values``."""
            if values.size != length:
                raise ValueError(f"Expected {length} values, got {values.size}")
            c_array = (c_type * length)()
            if length:
                # One memmove instead of a Python-level loop calling .item()
                # for every element.
                ctypes.memmove(c_array, values.ctypes.data, values.nbytes)
            return c_array

        def convert_to_c_types(params):
            """Pack one projection's per-layer parameters into an ObfParamArray."""
            entries = (ObfParam * len(params))()
            # Every buffer the structures point at is also collected here so
            # that its lifetime does not depend on ctypes' internal bookkeeping.
            keepalive = [entries]

            for i, param in enumerate(params):
                perm_tensor = torch.as_tensor(param["permutation"])
                perm_size = perm_tensor.shape[0]
                perm_values = np.ascontiguousarray(
                    perm_tensor.detach().cpu().reshape(-1).numpy(), dtype=np.intc
                )
                perm_array = filled_c_array(ctypes.c_int, perm_size, perm_values)

                matrices = param["blocks"]
                block_array = (Block * len(matrices))()
                for j, matrix in enumerate(matrices):
                    side = matrix.shape[0]
                    inverse = torch.linalg.inv(matrix).flatten()
                    inverse_values = np.ascontiguousarray(
                        inverse.detach().cpu().to(torch.float32).numpy(), dtype=np.float32
                    )
                    data_array = filled_c_array(ctypes.c_float, side * side, inverse_values)
                    block_array[j].size = side
                    block_array[j].data = data_array
                    keepalive.append(data_array)

                entry = entries[i]
                entry.perm_size = perm_size
                entry.block_count = len(matrices)
                entry.blocks = block_array
                entry.permutation = perm_array
                keepalive.extend((block_array, perm_array))

            packed = ObfParamArray()
            packed.count = len(params)
            packed.params = entries
            return packed, keepalive

        try:
            # Keep every packed structure referenced until the function
            # returns, as the original implementation did for its data arrays.
            all_keepalive = []
            for name, params in obf_param.items():
                sgx_param, keepalive = convert_to_c_types(params)
                all_keepalive.append((sgx_param, keepalive))
                ret = self.sgx_lib.prepare_obf_params(
                    self.enclave_id,
                    self.restore_enum[name],
                    ctypes.byref(sgx_param)
                )
                if ret != 0:
                    logger.error(f"Failed to prepare {name} obfuscation parameters in SGX: {ret}")
                    return False

            return True

        except Exception as e:
            raise RuntimeError(f"Error preparing obfuscation parameters in SGX: {e}") from e

    def prepare_norm_params(self, model, weight_add_1=False):
        """
        Prepare normalization parameters in SGX.

        Norm modules are looked up by attribute name on every decoder layer
        (``model.layers``, falling back to ``model.h``) and on the model itself
        (``ln_f`` / ``norm``). They are grouped by that name, packed into C
        structures and passed to the library's ``prepare_norm_params`` entry
        point together with the matching code from ``self.norm_enum``. Groups
        with no layers are skipped.

        Args:
            model: The model to prepare normalization parameters for.
            weight_add_1: When true, 1.0 is added to every weight value before
                it is uploaded.

        Returns:
            True when every group was accepted; False as soon as the library
            reports a non-zero status (the failure is logged).

        Raises:
            RuntimeError: If the library is not loaded, a layer exposes neither
                ``variance_epsilon`` nor ``eps``, or packing/uploading raises;
                the original error is chained.
        """
        if self.sgx_lib is None:
            raise RuntimeError("SGX library not initialized")

        class NormParam(ctypes.Structure):
            _fields_ = [
                ("size", ctypes.c_size_t),
                ("weight", ctypes.POINTER(ctypes.c_float)),
                ("bias", ctypes.POINTER(ctypes.c_float)),
                ("eps", ctypes.c_float),
            ]

        class NormParamArray(ctypes.Structure):
            _fields_ = [
                ("count", ctypes.c_size_t),
                ("params", ctypes.POINTER(NormParam)),
            ]

        def to_c_floats(tensor, size, offset=None):
            """Return a ctypes float array of ``size`` values copied from ``tensor``."""
            values = torch.as_tensor(tensor).detach().cpu().to(torch.float64).reshape(-1).numpy()
            if offset is not None:
                # Add in float64 and round to float32 once, which is what the
                # previous per-element Python-float assignment produced.
                values = values + offset
            values = np.ascontiguousarray(values, dtype=np.float32)
            if values.size != size:
                raise ValueError(f"Expected {size} values, got {values.size}")
            c_array = (ctypes.c_float * size)()
            if size:
                # One bulk copy instead of one .item() call per element.
                ctypes.memmove(c_array, values.ctypes.data, values.nbytes)
            return c_array

        def convert_to_c_types(norm_layers, weight_add_1):
            """Pack one group of norm layers into a NormParamArray."""
            add_item = 1.0 if weight_add_1 else 0.0
            entries = (NormParam * len(norm_layers))()
            # Explicit references to every buffer the structures point at.
            keepalive = [entries]

            for i, layer in enumerate(norm_layers):
                size = layer.weight.shape[0]

                if hasattr(layer, "variance_epsilon"):
                    eps = layer.variance_epsilon
                elif hasattr(layer, "eps"):
                    eps = layer.eps
                else:
                    raise ValueError(f"Normalization layer {i} does not have 'variance_epsilon' or 'eps' attribute")

                entry = entries[i]
                entry.size = size
                entry.eps = eps

                weight_array = to_c_floats(layer.weight, size, add_item)
                entry.weight = weight_array
                keepalive.append(weight_array)

                bias = getattr(layer, "bias", None)
                if bias is not None:
                    bias_array = to_c_floats(bias, size)
                    entry.bias = bias_array
                    keepalive.append(bias_array)
                else:
                    # NULL pointer: this layer has no bias.
                    entry.bias = None

            packed = NormParamArray()
            packed.count = len(norm_layers)
            packed.params = entries
            return packed, keepalive

        try:
            # Insertion order fixes the order in which groups are uploaded.
            norms = {
                "input_layernorm": [],
                "post_attention_layernorm": [],
                "q_norm": [],
                "k_norm": [],
                "ln_1": [],
                "ln_2": [],
                "ln_f": [],
                "norm": [],
                "pre_feedforward_layernorm": [],
                "post_feedforward_layernorm": [],
            }
            per_layer_names = (
                "input_layernorm",
                "post_attention_layernorm",
                "ln_1",
                "ln_2",
                "pre_feedforward_layernorm",
                "post_feedforward_layernorm",
            )

            layers = model.layers if hasattr(model, "layers") else model.h

            for layer in layers:
                for name in per_layer_names:
                    if hasattr(layer, name):
                        norms[name].append(getattr(layer, name))
                # q_norm / k_norm live on the attention sub-module.
                for name in ("q_norm", "k_norm"):
                    if hasattr(layer, "self_attn") and hasattr(layer.self_attn, name):
                        norms[name].append(getattr(layer.self_attn, name))
            for name in ("ln_f", "norm"):
                if hasattr(model, name):
                    norms[name].append(getattr(model, name))

            all_keepalive = []
            for name, norm_layers in norms.items():
                if len(norm_layers) == 0:
                    continue

                sgx_norm_param, keepalive = convert_to_c_types(norm_layers, weight_add_1)
                all_keepalive.append((sgx_norm_param, keepalive))
                ret = self.sgx_lib.prepare_norm_params(
                    self.enclave_id,
                    self.norm_enum[name],
                    ctypes.byref(sgx_norm_param)
                )
                if ret != 0:
                    logger.error(f"Failed to prepare {name} normalization parameters in SGX: {ret}")
                    return False

            return True

        except Exception as e:
            raise RuntimeError(f"Error preparing normalization parameters in SGX: {e}") from e

    def norm(self, x, layer_idx, norm_type):
        """
        RMSNorm And LayerNorm implementation in SGX using pre-stored parameters with identifier

        The input is converted to a C-contiguous float32 buffer before it is
        handed to the library, so non-contiguous tensors and dtypes NumPy
        cannot represent directly (e.g. bfloat16) are accepted. The result is
        always a float32 tensor placed on ``x``'s device.

        Args:
            x: Tensor of shape (batch_size, seq_length, hidden_size), or
                (batch_size, seq_length, num_heads, head_dim); the 4-D form is
                presented to the library as
                (batch_size * seq_length * num_heads, 1, head_dim).
            layer_idx: Integer index of the layer
            norm_type: String type of the norm layer (a key of ``self.norm_enum``)

        Returns:
            Output tensor with the same shape as ``x``

        Raises:
            RuntimeError: If the library is not loaded, ``norm_type`` is
                unknown, the shape is unsupported, or the library returns a
                non-zero status; the original error is chained.
        """
        if self.sgx_lib is None:
            raise RuntimeError("SGX library not initialized")

        try:
            # Synchronise only when CUDA exists: the unconditional call raises
            # on CPU-only hosts. Presumably the sync keeps pending GPU work out
            # of the timed region.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.perf_counter()

            # Convert norm_type to enum before paying for the tensor copy
            norm_type_enum = self.norm_enum.get(norm_type, -1)
            if norm_type_enum == -1:
                raise ValueError(f"No such type norm: {norm_type}")

            # Private, C-contiguous float32 copy: the library receives a raw
            # pointer and indexes it in row-major order.
            x_np = np.array(
                x.detach().cpu().to(torch.float32).numpy(), dtype=np.float32, order="C"
            )

            original_shape = x_np.shape
            if len(original_shape) == 4:
                # (batch, seq, heads, head_dim): every head vector is one row.
                # Because the buffer is C-contiguous, it already has the memory
                # layout of (batch * seq * heads, 1, head_dim).
                batch_size, seq_length, num_heads, head_dim = original_shape
                hidden_size = head_dim
                sgx_batch_size = batch_size * seq_length * num_heads
                sgx_seq_length = 1
            else:
                # For (batch_size, seq_length, hidden_size) shape
                batch_size, seq_length, hidden_size = original_shape
                sgx_batch_size = batch_size
                sgx_seq_length = seq_length

            output_np = np.zeros(original_shape, dtype=np.float32)
            float_ptr = ctypes.POINTER(ctypes.c_float)

            # Call SGX function
            ret = self.sgx_lib.norm(
                self.enclave_id,
                x_np.ctypes.data_as(float_ptr),
                layer_idx,
                norm_type_enum,
                sgx_batch_size,
                sgx_seq_length,
                hidden_size,
                output_np.ctypes.data_as(float_ptr)
            )

            if ret != 0:
                logger.error(f"Failed to execute Norm in SGX: {ret}")
                raise RuntimeError(f"Norm failed with code: {ret}")

            # Convert back to PyTorch tensor
            output = torch.from_numpy(output_np).to(x.device)

            # all_time accumulates milliseconds.
            self.all_time += (time.perf_counter() - start_time) * 1000
            return output

        except Exception as e:
            raise RuntimeError(f"Error executing Norm in SGX: {e}") from e

    def restore(self, x, layer_idx, layer_type, bias=None):
        """
        Implement _restore function in SGX

        The input (and bias, when given) is converted to a C-contiguous float32
        buffer before it is handed to the library. The result is always a
        float32 tensor placed on ``x``'s device.

        Args:
            x: Tensor of shape (batch_size, seq_length, hidden_size)
            layer_idx: Integer index of the layer
            layer_type: Projection name (a key of ``self.restore_enum``)
            bias: Tensor of shape (hidden_size,), optional

        Returns:
            Output tensor of shape (batch_size, seq_length, hidden_size)

        Raises:
            RuntimeError: If the library is not loaded, ``layer_type`` is
                unknown, ``x`` is not 3-D, or the library returns a non-zero
                status; the original error is chained.
        """
        if self.sgx_lib is None:
            raise RuntimeError("SGX library not initialized")

        try:
            # Synchronise only when CUDA exists: the unconditional call raises
            # on CPU-only hosts.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.perf_counter()

            # Reject unknown names here instead of sending -1 to the enclave,
            # matching how norm() treats unknown norm types.
            layer_enum = self.restore_enum.get(layer_type, -1)
            if layer_enum == -1:
                raise ValueError(f"No such layer type: {layer_type}")

            float_ptr = ctypes.POINTER(ctypes.c_float)

            # Private, C-contiguous float32 copy: the library receives a raw
            # pointer and indexes it in row-major order.
            x_np = np.array(
                x.detach().cpu().to(torch.float32).numpy(), dtype=np.float32, order="C"
            )

            # Get dimensions
            batch_size, seq_length, hidden_size_x = x_np.shape

            # Prepare bias (NULL pointer when absent). bias_np stays referenced
            # by this local for the duration of the call.
            bias_ptr = None
            if bias is not None:
                bias_np = np.array(
                    bias.detach().cpu().to(torch.float32).numpy(), dtype=np.float32, order="C"
                )
                bias_ptr = bias_np.ctypes.data_as(float_ptr)

            # Allocate output buffer
            output_np = np.zeros(x_np.shape, dtype=np.float32)

            # Call SGX function
            ret = self.sgx_lib.restore(
                self.enclave_id,
                x_np.ctypes.data_as(float_ptr),
                layer_idx,
                layer_enum,
                bias_ptr,
                batch_size,
                seq_length,
                hidden_size_x,
                output_np.ctypes.data_as(float_ptr)
            )

            if ret != 0:
                logger.error(f"Failed to execute restore function in SGX: {ret}")
                raise RuntimeError(f"Restore function failed with code: {ret}")

            # Convert back to PyTorch tensor
            output = torch.from_numpy(output_np).to(x.device)

            # all_time accumulates milliseconds.
            self.all_time += (time.perf_counter() - start_time) * 1000
            return output

        except Exception as e:
            raise RuntimeError(f"Error executing restore function in SGX: {e}") from e

    def silu_activation(self, x):
        """
        SiLU activation implementation in SGX

        The input is converted to a C-contiguous float32 buffer before it is
        handed to the library, so dtypes NumPy cannot represent directly
        (e.g. bfloat16) are accepted. The result is always a float32 tensor
        placed on ``x``'s device.

        Args:
            x: Tensor of shape (batch_size, seq_length, hidden_size)

        Returns:
            Output tensor of shape (batch_size, seq_length, hidden_size)

        Raises:
            RuntimeError: If the library is not loaded, ``x`` is not 3-D, or
                the library returns a non-zero status; the original error is
                chained.
        """
        if self.sgx_lib is None:
            raise RuntimeError("SGX library not initialized")

        try:
            # Synchronise only when CUDA exists: the unconditional call raises
            # on CPU-only hosts.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.perf_counter()

            # Private, C-contiguous float32 copy for the raw-pointer call.
            x_np = np.array(
                x.detach().cpu().to(torch.float32).numpy(), dtype=np.float32, order="C"
            )

            # Get dimensions
            batch_size, seq_length, hidden_size = x_np.shape

            # Allocate output buffer
            output_np = np.zeros(x_np.shape, dtype=np.float32)
            float_ptr = ctypes.POINTER(ctypes.c_float)

            # Call SGX function
            ret = self.sgx_lib.silu_activation(
                self.enclave_id,
                x_np.ctypes.data_as(float_ptr),
                batch_size,
                seq_length,
                hidden_size,
                output_np.ctypes.data_as(float_ptr)
            )

            if ret != 0:
                logger.error(f"Failed to execute SiLU activation in SGX: {ret}")
                raise RuntimeError(f"SiLU activation failed with code: {ret}")

            # Convert back to PyTorch tensor
            output = torch.from_numpy(output_np).to(x.device)
            # all_time accumulates milliseconds.
            self.all_time += (time.perf_counter() - start_time) * 1000
            return output

        except Exception as e:
            raise RuntimeError(f"Error executing SiLU activation in SGX: {e}") from e

    def gelu_tanh_activation(self, x):
        """
        GELU Tanh activation implementation in SGX

        The input is converted to a C-contiguous float32 buffer before it is
        handed to the library, so dtypes NumPy cannot represent directly
        (e.g. bfloat16) are accepted. The result is always a float32 tensor
        placed on ``x``'s device.

        Args:
            x: Tensor of shape (batch_size, seq_length, hidden_size)

        Returns:
            Output tensor of shape (batch_size, seq_length, hidden_size)

        Raises:
            RuntimeError: If the library is not loaded, ``x`` is not 3-D, or
                the library returns a non-zero status; the original error is
                chained.
        """
        if self.sgx_lib is None:
            raise RuntimeError("SGX library not initialized")

        try:
            # Synchronise only when CUDA exists: the unconditional call raises
            # on CPU-only hosts.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.perf_counter()

            # Private, C-contiguous float32 copy for the raw-pointer call.
            x_np = np.array(
                x.detach().cpu().to(torch.float32).numpy(), dtype=np.float32, order="C"
            )

            # Get dimensions
            batch_size, seq_length, hidden_size = x_np.shape

            # Allocate output buffer
            output_np = np.zeros(x_np.shape, dtype=np.float32)
            float_ptr = ctypes.POINTER(ctypes.c_float)

            # Call SGX function
            ret = self.sgx_lib.gelu_tanh_activation(
                self.enclave_id,
                x_np.ctypes.data_as(float_ptr),
                batch_size,
                seq_length,
                hidden_size,
                output_np.ctypes.data_as(float_ptr)
            )

            if ret != 0:
                logger.error(f"Failed to execute GELU Tanh activation in SGX: {ret}")
                raise RuntimeError(f"GELU Tanh activation failed with code: {ret}")

            # Convert back to PyTorch tensor
            output = torch.from_numpy(output_np).to(x.device)
            # all_time accumulates milliseconds.
            self.all_time += (time.perf_counter() - start_time) * 1000
            return output

        except Exception as e:
            raise RuntimeError(f"Error executing GELU Tanh activation in SGX: {e}") from e

    def new_gelu_activation(self, x):
        """
        New GELU activation implementation in SGX

        The input is converted to a C-contiguous float32 buffer before it is
        handed to the library, so dtypes NumPy cannot represent directly
        (e.g. bfloat16) are accepted. The result is always a float32 tensor
        placed on ``x``'s device.

        Args:
            x: Tensor of shape (batch_size, seq_length, hidden_size)

        Returns:
            Output tensor of shape (batch_size, seq_length, hidden_size)

        Raises:
            RuntimeError: If the library is not loaded, ``x`` is not 3-D, or
                the library returns a non-zero status; the original error is
                chained.
        """
        if self.sgx_lib is None:
            raise RuntimeError("SGX library not initialized")

        try:
            # Synchronise only when CUDA exists: the unconditional call raises
            # on CPU-only hosts.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.perf_counter()

            # Private, C-contiguous float32 copy for the raw-pointer call.
            x_np = np.array(
                x.detach().cpu().to(torch.float32).numpy(), dtype=np.float32, order="C"
            )

            # Get dimensions
            batch_size, seq_length, hidden_size = x_np.shape

            # Allocate output buffer
            output_np = np.zeros(x_np.shape, dtype=np.float32)
            float_ptr = ctypes.POINTER(ctypes.c_float)

            # Call SGX function
            ret = self.sgx_lib.new_gelu_activation(
                self.enclave_id,
                x_np.ctypes.data_as(float_ptr),
                batch_size,
                seq_length,
                hidden_size,
                output_np.ctypes.data_as(float_ptr)
            )

            if ret != 0:
                logger.error(f"Failed to execute New GELU activation in SGX: {ret}")
                raise RuntimeError(f"New GELU activation failed with code: {ret}")

            # Convert back to PyTorch tensor
            output = torch.from_numpy(output_np).to(x.device)
            # all_time accumulates milliseconds.
            self.all_time += (time.perf_counter() - start_time) * 1000
            return output

        except Exception as e:
            raise RuntimeError(f"Error executing New GELU activation in SGX: {e}") from e
        
    def reset_time(self):
        """Zero both Python-side timers, then reset the library's own timer."""
        self.sgx_compute_time = self.all_time = 0.0
        self.sgx_lib.reset_sgx_exe_time()
        
    def get_exe_time(self):
        """Return ``(sgx_compute_time, all_time)``.

        ``sgx_compute_time`` is re-read from the library's ``get_sgx_exe_time``
        on every call. It used to be fetched only while the cached value was
        0.0, so later calls returned a stale value next to an up-to-date
        ``all_time``. ``all_time`` is the number of milliseconds accumulated by
        the compute methods of this class.

        Raises:
            RuntimeError: If the SGX library has not been loaded.
        """
        if self.sgx_lib is None:
            raise RuntimeError("SGX library not initialized")
        self.sgx_compute_time = self.sgx_lib.get_sgx_exe_time()
        return self.sgx_compute_time, self.all_time
            
            

def get_sgx_instance():
    """
    Get singleton SGX instance

    The instance is created on the first call and stored in the module-level
    ``sgx_instance``; later calls return that same object.
    """
    global sgx_instance
    if sgx_instance is not None:
        return sgx_instance
    sgx_instance = SGX()
    return sgx_instance

# Register cleanup function: at interpreter exit, destroy the enclave if the
# singleton was created; when sgx_instance is still None this does nothing.
atexit.register(lambda: sgx_instance.destroy_enclave() if sgx_instance else None)
