# from typing import Optional, Tuple, Union, List, Dict

# import os
# import torch
# import torch.nn as nn
# from einops import rearrange
# import time

# from transformers import AutoModel
# from PIL import Image

# from mllmsd.utils.custom_files.conversation import get_conv_template

# class MLLM(nn.Module):
#     """
#     Wrapper for MLLM/LLM model to support various models
#     Following the same interface as ConditionalGeneration classes in transformers
#     """
#     def __init__(
#         self, 
#         model,
#     ):
#         super(MLLM, self).__init__()
#         self.mllm = model.cuda().eval()
    
#     def generate(
#             self,
#             input_ids: Optional[torch.Tensor] = None,
#             attention_mask: Optional[torch.Tensor] = None,
#             pixel_values: Optional[torch.Tensor] = None,
#             **kwargs,
#         ) -> Dict:

#         # TODO: avoid re-transferring the inputs to the device on every call (e.g. move them to the device once beforehand).
#         # load inputs to device
#         inputs = dict(
#             input_ids = input_ids.to(self.mllm.device) if input_ids is not None else None,
#             attention_mask = attention_mask.to(self.mllm.device) if attention_mask is not None else None,
#         )
#         if pixel_values is not None:
#             inputs['pixel_values'] = pixel_values.to(self.mllm.device)


#         # [MODEL]ForConditionalGeneration.generate
#         outputs_generate = self.mllm.generate(
#             **inputs,
#             use_cache=True,
#             output_logits=True,
#             output_hidden_states=False,
#             return_dict_in_generate=True, 
#             **kwargs,
#         )

#         return outputs_generate

#     def forward(
#         self,
#         input_ids: Optional[torch.Tensor] = None,
#         inputs_embeds: Optional[torch.Tensor] = None,
#         attention_mask: Optional[torch.Tensor] = None,
#         position_ids: Optional[torch.LongTensor] = None,
#         pixel_values: Optional[torch.Tensor] = None,
#         past_key_values: Optional[List[torch.FloatTensor]] = None,
#     ) -> Dict:

#         # load inputs to device
#         inputs = dict(
#             input_ids = input_ids.to(self.mllm.device) if input_ids is not None else None,
#             inputs_embeds = inputs_embeds.to(self.mllm.device) if inputs_embeds is not None else None,
#             attention_mask = attention_mask.to(self.mllm.device) if attention_mask is not None else None,
#             position_ids = position_ids.to(self.mllm.device) if position_ids is not None else None,
#         )
#         if pixel_values is not None:
#             inputs['pixel_values'] = pixel_values.to(self.mllm.device) if pixel_values is not None else None
        
#         # [MODEL]ForConditionalGeneration.forward
#         outputs_forward = self.mllm(
#             **inputs,
#             past_key_values=past_key_values,
#             use_cache=True,
#         )

#         return outputs_forward