import torch
from torch import nn

from reward_model.VideoAlign.inference import VideoVLMRewardInference

class VideoAlignWrapper(torch.nn.Module):
    def __init__(self, load_from_pretrained="", device=None, dtype=torch.bfloat16) -> None:
        """Build the wrapper around a ``VideoVLMRewardInference`` instance.

        Args:
            load_from_pretrained: Forwarded unchanged as the first positional
                argument of ``VideoVLMRewardInference`` (presumably a
                checkpoint path or identifier -- confirm against that class).
            device: Forwarded as ``device=``. ``None`` (the default) selects
                ``torch.cuda.current_device()`` at construction time.
            dtype: Forwarded as ``dtype=``; defaults to ``torch.bfloat16``.
        """
        super().__init__()

        # Resolve the default device lazily. A default argument expression is
        # evaluated once, when the ``def`` statement runs, so calling
        # ``torch.cuda.current_device()`` in the signature would touch CUDA at
        # import time and freeze the device that was current at import,
        # ignoring any later ``torch.cuda.set_device`` call.
        if device is None:
            device = torch.cuda.current_device()

        self.inferencer = VideoVLMRewardInference(
            load_from_pretrained,
            device=device,
            dtype=dtype,
        )

        # NOTE(review): the inferencer is not switched to eval mode here; an
        # earlier ``self.inferencer.eval()`` call was disabled -- confirm
        # whether VideoVLMRewardInference handles that itself.
        import gc

        # Force one full garbage-collection pass (presumably to drop
        # temporaries left over from constructing the inferencer).
        gc.collect()