import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

class TRTInference:
    """Run a serialized TensorRT engine inside a dedicated CUDA context.

    The instance owns one pycuda context created on ``gpu_id``. Construction
    leaves the caller's current-context stack as it found it; ``infer`` and
    ``cleanup`` push the owned context only for the duration of their work.

    Assumes the engine's I/O tensor 0 is the (single) input and tensor 1 the
    output -- TODO confirm for engines with more I/O tensors.

    Can be used as a context manager; ``cleanup`` runs on exit.
    """

    def __init__(self, engine_path: str, batch_size: int = 1, gpu_id: int = 0) -> None:
        """Create a CUDA context on ``gpu_id`` and deserialize the engine into it.

        Args:
            engine_path: Path to a serialized TensorRT engine file.
            batch_size: Stored as ``self.batch_size``; ``infer`` does not read it
                (the batch dimension is taken from the array passed to ``infer``).
            gpu_id: Ordinal of the CUDA device to run on.

        Raises:
            RuntimeError: If the engine cannot be deserialized or an execution
                context cannot be created. Any partially acquired resources are
                released before the exception propagates.
        """
        self.engine_path = engine_path
        self.batch_size = batch_size
        self.gpu_id = gpu_id
        # Pre-set so cleanup() is safe no matter where construction fails.
        self.ctx = None
        self.engine = None
        self.context = None
        # make_context() creates the context AND makes it current on this thread.
        self.ctx = cuda.Device(self.gpu_id).make_context()
        try:
            self._load_engine()
        except BaseException:
            # Undo make_context()'s push, then release whatever was built so far.
            self.ctx.pop()
            self.cleanup()
            raise
        # Leave the caller's context stack unchanged; infer()/cleanup() push on demand.
        self.ctx.pop()

    def _load_engine(self) -> None:
        """Deserialize the engine file and create its execution context.

        Must run while ``self.ctx`` is current. Sets ``engine``, ``context``,
        the input/output tensor names and their NumPy dtypes.

        Raises:
            RuntimeError: If deserialization or execution-context creation fails.
        """
        trt_logger = trt.Logger(trt.Logger.WARNING)
        with open(self.engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        # TensorRT reports these failures by returning None (and logging), not by raising.
        if self.engine is None:
            raise RuntimeError(
                f"Failed to deserialize TensorRT engine from {self.engine_path!r}"
            )
        self.context = self.engine.create_execution_context()
        if self.context is None:
            raise RuntimeError("Failed to create TensorRT execution context")
        self.input_name = self.engine.get_tensor_name(0)
        self.output_name = self.engine.get_tensor_name(1)
        self.input_dtype = trt.nptype(self.engine.get_tensor_dtype(self.input_name))
        self.output_dtype = trt.nptype(self.engine.get_tensor_dtype(self.output_name))
        # Engine shapes may contain -1 (dynamic); concrete shapes are resolved in infer().

    def infer(self, batch: np.ndarray) -> np.ndarray:
        """Run one synchronous inference on ``batch``.

        Args:
            batch: Input array (or array-like). It is converted to the engine's
                input dtype, and its full shape is bound as the input shape, so it
                must fall within the engine's optimization profile.

        Returns:
            A new host array holding the output tensor, shaped as the execution
            context reports for the bound input shape.

        Raises:
            RuntimeError: If called after ``cleanup()`` or if execution fails.
            ValueError: If TensorRT rejects the input shape.
        """
        if self.ctx is None:
            raise RuntimeError("TRTInference.infer() called after cleanup()")
        # memcpy_htod copies the raw buffer in memory order, so the host array must be
        # C-contiguous; astype()'s default order='K' would keep e.g. transposed layouts.
        host_input = np.ascontiguousarray(batch, dtype=self.input_dtype)
        self.ctx.push()
        d_input = d_output = None
        try:
            # Dynamic shapes: bind the concrete input shape, then ask the context for
            # the output shape it implies.
            if not self.context.set_input_shape(self.input_name, host_input.shape):
                raise ValueError(
                    f"TensorRT rejected input shape {host_input.shape} "
                    f"for tensor {self.input_name!r}"
                )
            output_shape = tuple(self.context.get_tensor_shape(self.output_name))
            output = np.empty(output_shape, dtype=self.output_dtype)
            d_input = cuda.mem_alloc(host_input.nbytes)
            d_output = cuda.mem_alloc(output.nbytes)
            cuda.memcpy_htod(d_input, host_input)
            # Bindings are ordered by I/O tensor index: [input (0), output (1)].
            if not self.context.execute_v2([int(d_input), int(d_output)]):
                raise RuntimeError("TensorRT execute_v2 reported failure")
            cuda.memcpy_dtoh(output, d_output)
        finally:
            # Free device buffers now, while our context is current, not at GC time.
            for buf in (d_input, d_output):
                if buf is not None:
                    buf.free()
            # Context.pop() removes the top of the stack, which is ours after push().
            self.ctx.pop()
        return output

    def cleanup(self) -> None:
        """Free the TensorRT objects, then release this instance's CUDA context.

        Safe to call more than once and on a partially constructed instance;
        calls after the first are no-ops.
        """
        ctx = getattr(self, "ctx", None)
        if ctx is None:
            return
        # Make our context current so TensorRT tears its objects down inside it.
        ctx.push()
        try:
            # Execution context first: it was created from the engine.
            self.context = None
            self.engine = None
        finally:
            ctx.pop()
        # Release our reference to the CUDA context (see pycuda Context.detach).
        ctx.detach()
        self.ctx = None

    def __del__(self) -> None:
        self.cleanup()

    def __enter__(self) -> "TRTInference":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returning None lets any exception from the with-body propagate.
        self.cleanup()
