import os
import torch
import types
from statistics import mean

from easyeditor import BaseEditor, MultimodalTrainer, MultimodalEditor
from easyeditor import CaptionDataset, VQADataset, CompositionalCaptionDataset
from easyeditor import MENDMultimodalTrainingHparams, SERACMultimodalTrainingHparams, IKEMultimodalHyperParams, MENDMultimodalHparams \
    , SERACMultimodalHparams, FTMultimodalHparams, OURSMultimodalHparams
from easyeditor import encode_ike_facts_multimodal
from sentence_transformers import SentenceTransformer
import sys
from datetime import datetime



def _module_setting(name, value):
    """Return *value*, or the module-level global *name* when *value* is None.

    The test functions below read ``gap_num``, ``hop`` and
    ``eval_comp_new_json_path`` from module globals by default; those globals
    are only assigned by this script's ``__main__`` block. Passing a value
    explicitly makes the functions usable without that block having run.

    Raises:
        NameError: if *value* is None and no module-level *name* is defined.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"'{name}' was not passed explicitly and is not defined at module "
            "level; pass it as an argument or run this file as a script."
        ) from None


def _run_ours_comp_500(hparams_path, gap_num=None, hop=None, json_path=None):
    """Run the sequential compositional test for one OURS hparams file.

    Args:
        hparams_path: YAML file loaded with ``OURSMultimodalHparams.from_hparams``.
        gap_num: forwarded to ``MultimodalTrainer.test_sequencial``; defaults
            to the module-level ``gap_num``.
        hop: forwarded to ``CompositionalCaptionDataset``; defaults to the
            module-level ``hop``.
        json_path: dataset file for ``CompositionalCaptionDataset``; defaults
            to the module-level ``eval_comp_new_json_path``.

    Raises:
        NameError: if a setting is neither passed nor defined at module level.
    """
    # Resolve every setting up front so a missing one fails before any
    # hparams/dataset loading starts.
    gap_num = _module_setting('gap_num', gap_num)
    hop = _module_setting('hop', hop)
    json_path = _module_setting('eval_comp_new_json_path', json_path)

    hparams = OURSMultimodalHparams.from_hparams(hparams_path)
    eval_ds = CompositionalCaptionDataset(json_path, config=hparams, hop=hop)
    # The evaluation set doubles as train_set and val_set.
    # NOTE(review): presumably the trainer requires a train_set even though
    # only test_sequencial is called here — confirm.
    trainer = MultimodalTrainer(
        config=hparams,
        train_set=eval_ds,
        val_set=eval_ds
    )
    trainer.test_sequencial(log=True, gap_num=gap_num, test_num=500, comp=True)


def test_LLaVA_OURS_comp_500_82_xx(gap_num=None, hop=None, json_path=None):
    """Run the LLaVA OURS compositional test (test_num=500, comp=True).

    Arguments left as None fall back to the module-level ``gap_num``,
    ``hop`` and ``eval_comp_new_json_path`` set by the ``__main__`` block.
    """
    _run_ours_comp_500(
        'hparams/OURS/0511/llava_test_comp500_OURS_82_xx.yaml',
        gap_num=gap_num, hop=hop, json_path=json_path,
    )


def test_MiniGPT4_OURS_comp_500_82_xx(gap_num=None, hop=None, json_path=None):
    """Run the MiniGPT-4 OURS compositional test (test_num=500, comp=True).

    Arguments left as None fall back to the module-level ``gap_num``,
    ``hop`` and ``eval_comp_new_json_path`` set by the ``__main__`` block.
    """
    _run_ours_comp_500(
        'hparams/OURS/0511/minigpt4_test_comp500_OURS_82_xx.yaml',
        gap_num=gap_num, hop=hop, json_path=json_path,
    )






if __name__ == "__main__":
    # Usage: python <this script> <function_name>
    # The named module-level function is called once per gap_num value below.
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} <function_name>")
        sys.exit(1)
    function_name = sys.argv[1]

    # Read as a global by the test functions (passed to CompositionalCaptionDataset).
    hop = 1

    # Outdated; not read by any function visible in this file.
    eval_comp_json_path = 'NO'


    train_json_path = 'datasets/train.json'
    train_comp_json_path = 'datasets/train_comp.json'
    eval_json_path = 'datasets/eval_multihop.json'
    # Read as a global by the test functions. Dataset after my own manual review.
    eval_comp_new_json_path = 'datasets/eval_comp_0411.json'
    if function_name not in globals() or not callable(globals()[function_name]):
        print(f"Error: Function '{function_name}' does not exist.")
        sys.exit(1)

    # `gap_num` is a module-level loop variable; the test functions read it as a global.
    # NOTE(review): no function named 'test_LLaVA_OURS_comp_500' is defined in
    # this file, so the existence check above exits before this branch can run.
    # It looks like a leftover from an earlier version — confirm before relying on it.
    if function_name == 'test_LLaVA_OURS_comp_500':
        for gap_num in [50, 100]:
            print('test_LLaVA_OURS_comp_500')
            globals()[function_name]()
    else:
        for gap_num in [0, 10, 20, 50, 100]:
            globals()[function_name]()
    
    
