import os
import ctypes
import numpy as np
import torch
import atexit
import time

import logging

# Custom formatter with colors for different log levels
class ColoredFormatter(logging.Formatter):
    """Formatter that prints ``[LEVEL] logger.name: message`` with an ANSI-colored level tag.

    Levels missing from ``COLORS`` use the reset code, i.e. they are not colored.
    Constructor arguments are forwarded to ``logging.Formatter`` but, as before,
    the layout above is always used.
    """

    # ANSI color codes
    COLORS = {
        logging.DEBUG: '\033[94m',    # Blue
        logging.INFO: '\033[92m',     # Green
        logging.WARNING: '\033[93m',  # Yellow
        logging.ERROR: '\033[91m',    # Red
        logging.CRITICAL: '\033[95m', # Purple
        'RESET': '\033[0m'            # Reset color
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # levelno -> inner Formatter; built once per level instead of once per record.
        self._level_formatters = {}

    def _formatter_for(self, levelno):
        """Return (building and caching on first use) the inner formatter for ``levelno``."""
        formatter = self._level_formatters.get(levelno)
        if formatter is None:
            color = self.COLORS.get(levelno, self.COLORS['RESET'])
            formatter = logging.Formatter(f'{color}[%(levelname)s]{self.COLORS["RESET"]} %(name)s: %(message)s')
            self._level_formatters[levelno] = formatter
        return formatter

    def format(self, record):
        """Format ``record`` with the color that belongs to its level."""
        return self._formatter_for(record.levelno).format(record)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = ColoredFormatter()
console_handler.setFormatter(formatter)
# getLogger() returns the same object for the same name, so re-executing this
# module (e.g. importlib.reload) would otherwise stack a second handler and
# print every record twice. Only attach when the logger has no handler yet.
if not logger.handlers:
    logger.addHandler(console_handler)

# Locations of the SGX shared library (loaded through ctypes) and the signed
# enclave image, both resolved relative to the directory of this file.
_MODULE_DIR = os.path.dirname(__file__)
SGX_LIB_PATH = os.path.join(_MODULE_DIR, '../../sgx/ours/build/sgx_ours.so')
ENCLAVE_PATH = os.path.join(_MODULE_DIR, '../../sgx/ours/build/enclave.signed.so')

# Process-wide SGX object: created lazily by init_sgx(), cleared by cleanup_sgx().
sgx_instance = None

class SGX(object):
    """ctypes wrapper around the SGX obfuscation library and the enclave it manages.

    Construction loads the shared library at ``SGX_LIB_PATH``, declares the C
    signatures of its entry points and initializes the enclave image at
    ``ENCLAVE_PATH``. The remaining methods copy NumPy arrays / PyTorch tensors
    into C buffers, call the matching library function and convert the results
    back to tensors.
    """

    def __init__(self):
        self.sgx_lib_path = SGX_LIB_PATH
        self.enclave_path = ENCLAVE_PATH
        self.sgx_lib = None
        self.enclave_id = None

        self.init_sgx()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def init_sgx(self):
        """Initialize SGX enclave and load the obfuscation library.

        On success ``self.sgx_lib`` is the loaded library and ``self.enclave_id``
        the value returned by its ``get_enclave_id``.

        Raises:
            RuntimeError: if loading the library or ``init_enclave`` fails. Both
                attributes are then reset to ``None`` so a later call retries
                instead of treating a half-initialized library as live.
        """
        try:
            # loading SGX lib
            logger.info(f"Loading SGX library from path: {self.sgx_lib_path}")
            sgx_lib = ctypes.CDLL(self.sgx_lib_path, mode=ctypes.RTLD_GLOBAL)
            self._declare_signatures(sgx_lib)

            # init SGX enclave
            logger.info(f"Initializing SGX enclave from path: {self.enclave_path}")
            ret = sgx_lib.init_enclave(self.enclave_path.encode('utf-8'))
            if ret != 0:
                raise RuntimeError(f"Failed to initialize SGX enclave: {ret}")

            self.sgx_lib = sgx_lib
            self.enclave_id = sgx_lib.get_enclave_id()
        except Exception as e:
            self.sgx_lib = None
            self.enclave_id = None
            raise RuntimeError(f"Failed to initialize SGX: {str(e)}") from e

    @staticmethod
    def _declare_signatures(sgx_lib):
        """Declare ``argtypes``/``restype`` for every library entry point this class calls."""
        eid_t = ctypes.c_ulonglong
        size_t = ctypes.c_size_t
        float_p = ctypes.POINTER(ctypes.c_float)

        sgx_lib.init_enclave.argtypes = [ctypes.c_char_p]
        sgx_lib.init_enclave.restype = ctypes.c_int

        sgx_lib.get_enclave_id.restype = eid_t

        sgx_lib.prepare_input_obf_params.argtypes = [eid_t] + [ctypes.c_void_p] * 5
        sgx_lib.prepare_input_obf_params.restype = ctypes.c_int

        sgx_lib.prepare_otp_params.argtypes = [eid_t, ctypes.c_void_p, ctypes.c_void_p]
        sgx_lib.prepare_otp_params.restype = ctypes.c_int

        sgx_lib.perform_obfuscation.argtypes = [eid_t, float_p, float_p,
                                                size_t, size_t, size_t,
                                                float_p, ctypes.c_int]
        sgx_lib.perform_obfuscation.restype = ctypes.c_int

        sgx_lib.perform_otp.argtypes = [eid_t, float_p,
                                        size_t, size_t, size_t,
                                        float_p]
        sgx_lib.perform_otp.restype = ctypes.c_int

        sgx_lib.perform_logits_recover.argtypes = [eid_t, float_p,
                                                   size_t, size_t, size_t,
                                                   ctypes.POINTER(ctypes.c_int64)]
        sgx_lib.perform_logits_recover.restype = ctypes.c_int

        # The result of destroy_enclave is never used; declaring it void (None)
        # keeps ctypes from converting whatever is left in the return register.
        sgx_lib.destroy_enclave.restype = None

    def _ensure_initialized(self):
        """Make sure this instance has a loaded library and a known enclave id.

        (Re)runs ``self.init_sgx()`` when the library is not loaded, e.g. after
        ``destroy_enclave``.

        Returns:
            True when ready; False when initialization failed (the error is logged).
        """
        if self.sgx_lib is None:
            try:
                self.init_sgx()
            except RuntimeError as e:
                logger.error(f"Error initializing SGX: {e}")
                return False
        if self.enclave_id is None:
            self.enclave_id = self.sgx_lib.get_enclave_id()
        return True

    def destroy_enclave(self):
        """
        Cleanup SGX resources

        Calls the library's ``destroy_enclave`` once and then drops the handle.
        Repeated calls, such as an explicit ``cleanup()`` followed by the
        module's atexit hook, are therefore no-ops. A later operation on this
        instance re-runs ``init_sgx`` and creates a new enclave. Parameters
        prepared earlier presumably have to be prepared again.
        """
        if self.sgx_lib is not None:
            self.sgx_lib.destroy_enclave()
            self.sgx_lib = None
            self.enclave_id = None

    def cleanup(self):
        """
        Cleanup SGX resources
        """
        self.destroy_enclave()

    # ------------------------------------------------------------------
    # Marshalling helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _as_float32_array(data):
        """Return a fresh, C-contiguous float32 NumPy copy of a tensor or array.

        The library receives bare pointers and dimension counts, so the buffer
        must be row-major. ``astype`` alone keeps the source's memory order
        (order='K'), which scrambles transposed/permuted inputs. Tensors are
        upcast to float32 before ``.numpy()`` because bfloat16 has no NumPy dtype.
        """
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().to(torch.float32).numpy()
        return np.array(data, dtype=np.float32, order='C')

    @staticmethod
    def _to_c_array(values, c_type, np_dtype):
        """Copy ``values`` (flattened) into a newly allocated ctypes array of ``c_type``.

        ``np_dtype`` must have the same item size as ``c_type`` (float32 for
        c_float, uint32 for c_uint32). The copy is a single C-level pass
        instead of a Python loop over elements.
        """
        flat = np.ascontiguousarray(np.asarray(values, dtype=np_dtype).reshape(-1))
        return (c_type * flat.size).from_buffer_copy(flat)

    @staticmethod
    def _vector_values(vec):
        """Return the numeric values of one vector (tensor, ndarray, list or tuple)."""
        if isinstance(vec, torch.Tensor):
            return vec.detach().cpu().to(torch.float32).numpy()
        if isinstance(vec, np.ndarray):
            return vec
        if isinstance(vec, (list, tuple)):
            return [v.item() if hasattr(v, 'item') else float(v) for v in vec]
        raise ValueError(f"Unsupported vector type: {type(vec)}")

    @staticmethod
    def _index_values(idx):
        """Return the integer values of one index sequence (tensor, ndarray, list or tuple)."""
        if isinstance(idx, torch.Tensor):
            return idx.detach().cpu().to(torch.int64).numpy()
        if isinstance(idx, np.ndarray):
            return idx
        if isinstance(idx, (list, tuple)):
            return [v.item() if hasattr(v, 'item') else int(v) for v in idx]
        raise ValueError(f"Unsupported indices type: {type(idx)}")

    @staticmethod
    def _inverse_permutation(idx):
        """Return ``argsort(idx)``; for a permutation this is its inverse.

        Tensors use ``torch.argsort``. Lists, tuples and arrays use NumPy.
        """
        if isinstance(idx, torch.Tensor):
            return torch.argsort(idx)
        return np.argsort(np.asarray(idx), kind='stable')

    @classmethod
    def _build_vector_list_list(cls, v_list):
        """Marshal a list of lists of vectors into a nested ``VectorListList`` struct.

        ctypes records every array assigned to a pointer field in the owning
        object's ``_objects``. The returned struct therefore keeps all nested
        buffers alive for as long as it is referenced.
        """
        class Vector(ctypes.Structure):
            _fields_ = [
                ("vector_size", ctypes.c_size_t),
                ("vectors", ctypes.POINTER(ctypes.c_float))
            ]

        class VectorList(ctypes.Structure):
            _fields_ = [
                ("num_vectors", ctypes.c_size_t),
                ("vector_list", ctypes.POINTER(Vector))
            ]

        class VectorListList(ctypes.Structure):
            _fields_ = [
                ("num_lists", ctypes.c_size_t),
                ("vector_list_list", ctypes.POINTER(VectorList))
            ]

        vector_lists = (VectorList * len(v_list))()
        for i, vecs in enumerate(v_list):
            vec_list = (Vector * len(vecs))()
            for j, vec in enumerate(vecs):
                vec_array = cls._to_c_array(cls._vector_values(vec), ctypes.c_float, np.float32)
                vec_list[j].vector_size = len(vec_array)
                vec_list[j].vectors = vec_array
            vector_lists[i].num_vectors = len(vecs)
            vector_lists[i].vector_list = vec_list

        ret_lists = VectorListList()
        ret_lists.num_lists = len(v_list)
        ret_lists.vector_list_list = vector_lists
        return ret_lists

    @classmethod
    def _build_indices_list(cls, indices_list):
        """Marshal a list of index sequences into an ``IndicesList`` struct of uint32 arrays.

        As with ``_build_vector_list_list``, the returned struct keeps the
        nested ctypes arrays alive.
        """
        class Indices(ctypes.Structure):
            _fields_ = [
                ("num_indices", ctypes.c_size_t),
                ("indices", ctypes.POINTER(ctypes.c_uint32))
            ]

        class IndicesList(ctypes.Structure):
            _fields_ = [
                ("num_lists", ctypes.c_size_t),
                ("indices_list", ctypes.POINTER(Indices))
            ]

        indices = (Indices * len(indices_list))()
        for i, idx in enumerate(indices_list):
            idx_array = cls._to_c_array(cls._index_values(idx), ctypes.c_uint32, np.uint32)
            indices[i].num_indices = len(idx_array)
            indices[i].indices = idx_array

        indices_list_struct = IndicesList()
        indices_list_struct.num_lists = len(indices_list)
        indices_list_struct.indices_list = indices
        return indices_list_struct

    @classmethod
    def _build_otp_struct(cls, data, size_field, data_field, label):
        """Wrap a 3-D tensor/array as ``{batch_size, logits_to_keep, <size_field>, <data_field>*}``.

        Returns:
            ``(struct, buffer)``. The struct only holds a raw pointer into
            ``buffer``, so the caller must keep ``buffer`` referenced until the
            library call returns.

        Raises:
            ValueError: if ``data`` is neither a tensor nor an ndarray, or is not 3-D.
        """
        class Params(ctypes.Structure):
            _fields_ = [
                ("batch_size", ctypes.c_size_t),
                ("logits_to_keep", ctypes.c_size_t),
                (size_field, ctypes.c_size_t),
                (data_field, ctypes.POINTER(ctypes.c_float))
            ]

        if not isinstance(data, (torch.Tensor, np.ndarray)):
            raise ValueError(f"Unsupported {label} type: {type(data)}")

        buffer = cls._as_float32_array(data)
        batch_size, logits_to_keep, last_dim = buffer.shape

        params = Params()
        params.batch_size = batch_size
        params.logits_to_keep = logits_to_keep
        setattr(params, size_field, last_dim)
        setattr(params, data_field, buffer.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))
        return params, buffer

    # ------------------------------------------------------------------
    # Library entry points
    # ------------------------------------------------------------------

    def prepare_input_obf_params(self, v_list0, indices_list0, v_list, indices_list):
        """
        Initialize input obfuscation parameters in SGX

        Besides the four given structures, the argsort of every entry of
        ``indices_list`` (its inverse permutation) is passed as a fifth argument.

        Args:
            v_list0: List of lists of vectors for inverse obfuscation
            indices_list0: List of indices for inverse obfuscation
            v_list: List of lists of vectors for obfuscation
            indices_list: List of indices for obfuscation

        Returns:
            True on success. False if SGX could not be initialized or the library
            returned a non-zero status.

        Raises:
            RuntimeError: if converting the inputs or calling the library raised.
        """
        if not self._ensure_initialized():
            return False

        try:
            v_list0_struct = self._build_vector_list_list(v_list0)
            indices_list0_struct = self._build_indices_list(indices_list0)
            v_list_struct = self._build_vector_list_list(v_list)
            indices_list_struct = self._build_indices_list(indices_list)
            indices_list_inv_struct = self._build_indices_list(
                [self._inverse_permutation(idx) for idx in indices_list]
            )

            ret = self.sgx_lib.prepare_input_obf_params(
                self.enclave_id,
                ctypes.byref(v_list0_struct),
                ctypes.byref(indices_list0_struct),
                ctypes.byref(v_list_struct),
                ctypes.byref(indices_list_struct),
                ctypes.byref(indices_list_inv_struct),
            )

            if ret != 0:
                logger.error(f"Failed to initialize obfuscation in SGX: {ret}")
                return False

            return True

        except Exception as e:
            raise RuntimeError(f"Error initializing obfuscation in SGX: {e}") from e

    def prepare_otp_params(self, otp, otp_logits):
        """
        Initialize otp parameters in SGX

        Args:
            otp: Tensor of shape (batch_size, logits_to_keep, hidden_size)
            otp_logits: Tensor of shape (batch_size, logits_to_keep, vocab_size)

        Returns:
            True on success. False if SGX could not be initialized or the library
            returned a non-zero status.

        Raises:
            RuntimeError: if converting the inputs or calling the library raised.
        """
        if not self._ensure_initialized():
            return False

        try:
            # NOTE: the buffers must stay referenced until the library call returns,
            # because the structs only hold raw pointers into them.
            otp_params_struct, otp_buffer = self._build_otp_struct(
                otp, "hidden_size", "data", "otp")
            otp_logits_params_struct, otp_logits_buffer = self._build_otp_struct(
                otp_logits, "vocab_size", "logits", "otp_logits")

            ret = self.sgx_lib.prepare_otp_params(
                self.enclave_id,
                ctypes.byref(otp_params_struct),
                ctypes.byref(otp_logits_params_struct)
            )

            if ret != 0:
                logger.error(f"Failed to initialize otp params in SGX: {ret}")
                return False

            return True

        except Exception as e:
            raise RuntimeError(f"Error initializing otp params in SGX: {e}") from e

    def perform_obfuscation(self, hidden_states, residual, optimized_stage):
        """
        Perform obfuscation operation in SGX

        Args:
            hidden_states: Tensor of shape (batch_size, seq_length, hidden_size)
            residual: Tensor with the same number of elements as hidden_states
            optimized_stage: Integer forwarded to the library (the module-level
                wrapper documents >0 as optimized, -1 as default)

        Returns:
            Obfuscated tensor of shape (batch_size, seq_length, hidden_size), with
            the dtype and device of hidden_states

        Raises:
            RuntimeError: if SGX is unavailable, the inputs are inconsistent, or
                the library reports an error.
        """
        if not self._ensure_initialized():
            raise RuntimeError("Failed to initialize SGX")

        try:
            hidden_states_np = self._as_float32_array(hidden_states)
            residual_np = self._as_float32_array(residual)

            batch_size, seq_length, hidden_size = hidden_states_np.shape
            # The library is given only hidden_states' dimensions, so it would read
            # past the end of a smaller residual buffer.
            if residual_np.size != hidden_states_np.size:
                raise ValueError(
                    f"residual has {residual_np.size} elements, expected {hidden_states_np.size}"
                )

            output_np = np.zeros(hidden_states_np.shape, dtype=np.float32)
            float_p = ctypes.POINTER(ctypes.c_float)

            ret = self.sgx_lib.perform_obfuscation(
                self.enclave_id,
                hidden_states_np.ctypes.data_as(float_p),
                residual_np.ctypes.data_as(float_p),
                batch_size,
                seq_length,
                hidden_size,
                output_np.ctypes.data_as(float_p),
                optimized_stage
            )

            if ret != 0:
                raise RuntimeError(f"Failed to perform obfuscation in SGX: {ret}")

            return torch.from_numpy(output_np).to(device=hidden_states.device, dtype=hidden_states.dtype)

        except Exception as e:
            logger.error(f"Error performing obfuscation in SGX: {e}")
            raise RuntimeError(f"Error performing obfuscation in SGX: {e}") from e

    def perform_otp(self, hidden_states):
        """
        Perform otp(one-time-pad) operation in SGX

        Args:
            hidden_states: Tensor of shape (batch_size, logits_to_keep, hidden_size)

        Returns:
            Obfuscated tensor of shape (batch_size, logits_to_keep, hidden_size), with
            the dtype and device of hidden_states

        Raises:
            RuntimeError: if SGX is unavailable or the library reports an error.
        """
        if not self._ensure_initialized():
            raise RuntimeError("Failed to initialize SGX")

        try:
            hidden_states_np = self._as_float32_array(hidden_states)
            batch_size, logits_to_keep, hidden_size = hidden_states_np.shape

            output_np = np.zeros(hidden_states_np.shape, dtype=np.float32)
            float_p = ctypes.POINTER(ctypes.c_float)

            ret = self.sgx_lib.perform_otp(
                self.enclave_id,
                hidden_states_np.ctypes.data_as(float_p),
                batch_size,
                logits_to_keep,
                hidden_size,
                output_np.ctypes.data_as(float_p)
            )

            if ret != 0:
                raise RuntimeError(f"Failed to perform otp in SGX: {ret}")

            return torch.from_numpy(output_np).to(device=hidden_states.device, dtype=hidden_states.dtype)

        except Exception as e:
            logger.error(f"Error performing otp in SGX: {e}")
            raise RuntimeError(f"Error performing otp in SGX: {e}") from e

    def perform_logits_recover(self, hidden_states):
        """
        Perform logits recover in SGX

        The library writes one int64 index per (batch, position). The result is
        a zero tensor shaped like hidden_states with 1.0 at each returned index
        along the last dimension.

        Args:
            hidden_states: Tensor of shape (batch_size, logits_to_keep, hidden_size)

        Returns:
            one-hot embedding logits of shape (batch_size, logits_to_keep, hidden_size)

        Raises:
            RuntimeError: if SGX is unavailable or the library reports an error.
        """
        if not self._ensure_initialized():
            raise RuntimeError("Failed to initialize SGX")

        try:
            hidden_states_np = self._as_float32_array(hidden_states)
            batch_size, logits_to_keep, hidden_size = hidden_states_np.shape

            token_ids = np.zeros(batch_size * logits_to_keep, dtype=np.int64)

            ret = self.sgx_lib.perform_logits_recover(
                self.enclave_id,
                hidden_states_np.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
                batch_size,
                logits_to_keep,
                hidden_size,
                token_ids.ctypes.data_as(ctypes.POINTER(ctypes.c_int64))
            )

            if ret != 0:
                raise RuntimeError(f"Failed to perform logits recover in SGX: {ret}")

            index_tensor = torch.from_numpy(token_ids).view(batch_size, logits_to_keep, 1).to(hidden_states.device)
            output = torch.zeros_like(hidden_states)
            return output.scatter_(-1, index_tensor, 1.0)

        except Exception as e:
            logger.error(f"Error performing logits recover in SGX: {e}")
            raise RuntimeError(f"Error performing logits recover in SGX: {e}") from e

def init_sgx():
    """
    Initialize SGX enclave

    Creates the process-wide ``sgx_instance`` on first use and reuses it after that.

    Returns:
        True when an instance exists afterwards. False when constructing one
        raised; the error is logged and ``sgx_instance`` stays ``None``.
    """
    global sgx_instance

    if sgx_instance is None:
        try:
            sgx_instance = SGX()
        except Exception as e:
            logger.error(f"Error initializing SGX: {e}")
            sgx_instance = None
            return False
    return True

def cleanup_sgx():
    """
    Cleanup SGX resources

    Destroys the enclave of the process-wide ``sgx_instance`` (if any) and then
    clears the global. Does nothing when no instance exists.
    """
    global sgx_instance

    instance = sgx_instance
    if instance is None:
        return
    instance.destroy_enclave()
    sgx_instance = None

def prepare_input_obf_params(v_list0, indices_list0, v_list, indices_list):
    """
    Initialize input obfuscation parameters in SGX

    Module-level wrapper: creates the shared ``sgx_instance`` through
    ``init_sgx()`` if needed, then delegates to its
    ``prepare_input_obf_params`` method.

    Args:
        v_list0: List of lists of vectors for inverse obfuscation
        indices_list0: List of indices for inverse obfuscation
        v_list: List of lists of vectors for obfuscation
        indices_list: List of indices for obfuscation

    Returns:
        False if no SGX instance could be created, otherwise the result of the
        instance method.
    """
    ready = sgx_instance is not None or init_sgx()
    if not ready:
        return False
    return sgx_instance.prepare_input_obf_params(v_list0, indices_list0, v_list, indices_list)

def prepare_otp_params(otp, otp_logits):
    """
    Initialize otp(one-time-pad) parameters in SGX

    Module-level wrapper around ``SGX.prepare_otp_params`` on the shared
    ``sgx_instance``, which is created on demand.

    Args:
        otp: Tensor of shape (batch_size, logits_to_keep, hidden_size)
        otp_logits: Tensor of shape (batch_size, logits_to_keep, vocab_size)

    Returns:
        False if no SGX instance could be created, otherwise the result of the
        instance method.
    """
    if sgx_instance is None and not init_sgx():
        return False
    return sgx_instance.prepare_otp_params(otp, otp_logits)

def perform_obfuscation(hidden_states, residual, optimized_stage=-1):
    """
    Perform obfuscation operation in SGX

    Module-level wrapper around ``SGX.perform_obfuscation`` on the shared
    ``sgx_instance``, which is created on demand.

    Args:
        hidden_states: Tensor of shape (batch_size, seq_length, hidden_size)
        residual: Tensor of shape (batch_size, seq_length, hidden_size)
        optimized_stage: Integer, >0 for optimized version, -1 for default version

    Returns:
        Obfuscated tensor of shape (batch_size, seq_length, hidden_size)

    Raises:
        RuntimeError: if no SGX instance could be created.
    """
    have_instance = sgx_instance is not None or init_sgx()
    if not have_instance:
        raise RuntimeError("Failed to initialize SGX")
    return sgx_instance.perform_obfuscation(hidden_states, residual, optimized_stage)

def perform_otp(hidden_states):
    """
    Perform otp operation in SGX

    Module-level wrapper around ``SGX.perform_otp`` on the shared
    ``sgx_instance``, which is created on demand.

    Args:
        hidden_states: Tensor of shape (batch_size, logits_to_keep, hidden_size)

    Returns:
        Obfuscated tensor of shape (batch_size, logits_to_keep, hidden_size)

    Raises:
        RuntimeError: if no SGX instance could be created.
    """
    if sgx_instance is None and not init_sgx():
        raise RuntimeError("Failed to initialize SGX")
    return sgx_instance.perform_otp(hidden_states)

def perform_logits_recover(hidden_states):
    """
    Perform logits recover in SGX

    Module-level wrapper around ``SGX.perform_logits_recover`` on the shared
    ``sgx_instance``, which is created on demand.

    Args:
        hidden_states: Tensor of shape (batch_size, logits_to_keep, hidden_size)

    Returns:
        one-hot embedding logits of shape (batch_size, logits_to_keep, hidden_size)

    Raises:
        RuntimeError: if no SGX instance could be created.
    """
    initialized = sgx_instance is not None or init_sgx()
    if not initialized:
        raise RuntimeError("Failed to initialize SGX")
    return sgx_instance.perform_logits_recover(hidden_states)

# atexit.register returns the function it is given, so registering through the
# decorator binds `cleanup` exactly as a separate register call would.
@atexit.register
def cleanup():
    """
    Cleanup SGX resources

    Runs automatically at interpreter exit and may also be called directly;
    delegates to ``cleanup_sgx()``.
    """
    cleanup_sgx()