from typing import Optional, Tuple, Union, List, Dict

import os
import torch
import torch.nn as nn
from einops import rearrange
import time

from transformers import PreTrainedModel
from PIL import Image

from mllmsd.utils.custom_files.conversation import get_conv_template

class MLLM(nn.Module):
    """
    Wrapper for MLLM/LLM model to support various models.

    Exposes the same ``generate`` entry point as the ConditionalGeneration
    classes in transformers by delegating to the wrapped model.

    Attributes:
        mllm: The wrapped model, already moved to its target device and
            switched to evaluation mode.
    """

    def __init__(
        self,
        model: "PreTrainedModel",
        device: Optional[Union[str, torch.device]] = None,
    ):
        """
        Move ``model`` to its target device and put it in eval mode.

        Args:
            model: The model to wrap. It must provide ``cuda()``/``to()``,
                ``eval()`` and ``generate()``.
            device: Where to place the model. ``None`` (the default) keeps
                the historical behavior of calling ``model.cuda()``; any
                other value is forwarded to ``model.to(device)``, which
                allows CPU-only hosts or an explicit GPU index.
        """
        super().__init__()
        # ``None`` preserves the original unconditional ``.cuda()`` call so
        # existing callers see no change in placement.
        placed = model.cuda() if device is None else model.to(device)
        self.mllm = placed.eval()

    def generate(self, *args, **kwargs):
        """Forward all arguments to the wrapped model's ``generate``."""
        return self.mllm.generate(*args, **kwargs)