import logging
import os
import sys

# Restrict this process to GPU 6 unless the caller already chose devices.
# This runs before `import torch` further down so the setting is in place
# before CUDA can be initialised (CUDA reads CUDA_VISIBLE_DEVICES at init).
# `setdefault` respects an exported value, e.g.
# `CUDA_VISIBLE_DEVICES=0 python <this script>`, instead of clobbering it.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '6')

# Setup logging: timestamped INFO-and-above records written to stdout.
logging.basicConfig(
    format="%(asctime)s|%(name)s:%(lineno)s|%(levelname)s - %(message)s",
    datefmt="%Y/%m/%d %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)

import torch

from aoeb_ng.task import RetrievalTask
from aoeb_ng.model import SentenceTransformerEmbedder, SentenceTransformerMultimodalEmbedder, st_text_length
from aoeb_ng.models.jinav4 import JinaV4Wrapper
from aoeb_ng.models.vlm2vec import VLM2VecWrapper
from aoeb_ng.models.vista import vista_loader



def main(
    task_config='/data8//aoeb/aoeb_ng/tasks/MMRC-session.json',
    output_dir="results/gme-2b",
    batch_size=2,
):
    """Evaluate one embedding model on one retrieval task.

    The model is chosen by editing the construction code below; alternative
    embedders are kept commented out so they can be swapped in quickly.

    Args:
        task_config: Path to the task JSON given to ``RetrievalTask.from_config``.
        output_dir: Second argument to ``task.evaluate`` (presumably where
            results are written). The default appears to be named after the
            active model (gme-2b), so change it when swapping models.
        batch_size: Encoding batch size, forwarded via ``encode_kwargs``.
    """
    # Other task configs previously used with this script:
    #   '/data8//aoeb/aoeb_ng/tasks/FreshStackDocAngular.json'
    #   '/data8//aoeb/aoeb_ng/tasks/MMRC.json'
    #   '/data8//aoeb/aoeb_ng/tasks/MRMR-design.json'
    task = RetrievalTask.from_config(task_config)

    # Use the GPU whenever one is visible. The old `device_count() == 1`
    # check silently fell back to CPU as soon as more than one device was
    # exposed through CUDA_VISIBLE_DEVICES.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # model = SentenceTransformerEmbedder("/data8/xxx/pretrained_models/bge-large-en-v1.5", device=device)
    # model.model = model.model.half()

    # model = JinaV4Wrapper("/data8/xxx/pretrained_models/jina-embeddings-v4")
    # model.model.task = 'retrieval'

    model = SentenceTransformerMultimodalEmbedder("/data8/xxx/pretrained_models/gme-Qwen2-VL-2B-Instruct/", device=device)
    # Replace the inner model's `_text_length` with the project helper
    # (presumably so length-based batching copes with multimodal inputs —
    # confirm in aoeb_ng.model). Assigned on the instance, so it is called
    # unbound, without `self`.
    model.model._text_length = st_text_length

    # model = SentenceTransformerMultimodalEmbedder(
    #     "/data8/xxx/pretrained_models/mmE5-mllama-11b-instruct/", trust_remote_code=True, device=device
    # )  # TODO: need instruction
    # model.model._text_length = st_text_length

    # model = VLM2VecWrapper(
    #     "/data8/xxx/pretrained_models/VLM2Vec-LoRA/",
    #     base_model_name="/data8/xxx/pretrained_models/Phi-3.5-vision-instruct",
    #     device=device,
    # )

    # model = vista_loader(
    #     "/data8/xxx/pretrained_models/bge-base-en-v1.5/",
    #     model_weight="/data8/xxx/pretrained_models/bge-visualized/Visualized_base_en_V1.5.pth",
    #     image_tokens_num=196,
    # )

    task.evaluate(model, output_dir, encode_kwargs={'batch_size': batch_size})
    print("ok.")



if __name__ == "__main__":
    # Run the evaluation only when executed directly, never on import.
    main()

"""
Extra dependencies required to run this script:
    pip install 'datasets>=4.4' pytrec_eval
"""