from preference_dataset.generate_preference_dataset import GeneratePere
from datasets_ import Animals, RealworldQA, MMBench, MMStar, SeedBench, ScienceQA
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import torch

if __name__ == '__main__':
    # Script entry point: build a dataset, load the Qwen2-VL model and its
    # processor, then hand everything to GeneratePere and run
    # generate_preference_dataset().

    # Dataset selection. Exactly one assignment should be active; the others
    # are kept as ready-to-use alternatives.
    # NOTE(review): the output/checkpoint paths passed to GeneratePere below
    # are all named after "science" -- presumably they must be changed
    # together with the dataset; confirm before switching.
    # dataset = Animals(root='data/Animals_with_Attributes2', test=False, return_image_path=True)
    # dataset = RealworldQA(root='data/realworldQA/data', return_image_path=True)
    # dataset = MMBench(root='data/mmbench/data', return_image_path=True)
    # dataset = MMStar(root='data/mmstar', return_image_path=True)
    # dataset = SeedBench(root='data/seedbench/data', return_image_path=True)
    dataset = ScienceQA(root='data/science/data', return_image_path=True)

    # Lower and upper pixel bounds forwarded to the processor, expressed as
    # multiples of a 28x28 area (256 and 512 such units respectively).
    patch_area = 28 * 28
    min_pixels = 256 * patch_area
    max_pixels = 512 * patch_area

    # The model weights and the processor are loaded from the same location.
    model_dir = 'Qwen2-VL-2B/Qwen2-VL-2B-Instruct'

    print('load model')
    vlm = Qwen2VLForConditionalGeneration.from_pretrained(
        model_dir,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    vlm_processor = AutoProcessor.from_pretrained(
        model_dir,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
    )
    print('done')

    # Wire the dataset, model and on-disk locations into the generator and
    # kick off preference-dataset generation.
    generator = GeneratePere(
        original_data=dataset,
        batch_size=1,
        data_save_path='preference_json_data/science_new',
        vlm_processor=vlm_processor,
        vlm=vlm,
        mu_path='mu/science',
        activate_vector_path='activate_vector/science_preference',
        checkpoint_path='lrm_science.pth',
    )
    generator.generate_preference_dataset()

