import json, os
import sys
sys.path.append(('../'))
sys.path.append(('../../'))

# Remaining imports; `utils` / `models` are presumably resolved via the relative sys.path entries above
import json
import argparse
import importlib
from utils.infer_on_data import *
# from models.load_LLaVA import *

# Maps each accepted --mllm value to the name of the module under `models/`
# that is imported for it (see the importlib call further down).
mllm_to_module = dict(
    gpt4="load_GPT4o",
    llava="load_LLaVA",
    instructblip="load_instructblip",
    idefics="load_idefics",
    deepseek="load_deepseek",
    qwenvl_7b_instruct="load_Qwen_VL_7B_Instruct",
)

# Command-line options:
#   --mllm        which model wrapper to import from `models/` (a key of mllm_to_module)
#   --data_root   directory that contains the MSSBench `combined.json`
#   --output_dir  directory for the raw responses and the evaluation results
parser = argparse.ArgumentParser()
parser.add_argument(
    "--mllm",
    type=str,
    default="llava",
    choices=mllm_to_module.keys(),
)
parser.add_argument("--data_root", type=str, default="MSSBENCH_DATA")
parser.add_argument("--output_dir", type=str, default="OUTPUT_DIR")
args = parser.parse_args()

# Import the model wrapper selected by --mllm and expose its public names at
# module level, similar to `from models.<module> import *`. This presumably
# supplies `call_model`, which is passed to test_each_mss further down.
module_name = f"models.{mllm_to_module[args.mllm]}"
model_module = importlib.import_module(module_name)
# Copy only public names. vars(module) also contains dunders such as
# `__name__`, `__file__` and `__spec__`; copying those would overwrite this
# script's own metadata (e.g. `__name__` would no longer be "__main__").
globals().update(
    {
        name: value
        for name, value in vars(model_module).items()
        if not name.startswith("_")
    }
)

# Load the MSSBench examples. JSON is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's default encoding; the
# `with` block also guarantees the file handle is closed.
with open(os.path.join(args.data_root, "combined.json"), "r", encoding="utf-8") as f:
    val_data = json.load(f)

# Both the raw model responses and the GPT-4 evaluation go to --output_dir.
responses_path = os.path.join(args.output_dir, f"{args.mllm}_mssbench.json")
save_file = os.path.join(args.output_dir, f"{args.mllm}_mssbench_eval.json")

# Create the output directory before inference so the responses file never
# targets a missing directory. An empty --output_dir means the current
# directory, for which os.makedirs("") would raise, so skip it then.
if args.output_dir:
    os.makedirs(args.output_dir, exist_ok=True)

# Run the selected model over every example; the responses are expected at
# `responses_path`, from where they are read back below.
test_each_mss(val_data, call_model, args.data_root, output_path=responses_path)

with open(responses_path, "r") as f:
    responses = json.load(f)

# Score the responses. gpt4_eval returns six accuracies (chat / embodied x
# safe / unsafe / total) and presumably writes its detailed results to
# `save_file` — confirm against its definition.
(
    c_safe_acc,
    c_unsafe_acc,
    c_total_acc,
    e_safe_acc,
    e_unsafe_acc,
    e_total_acc,
) = gpt4_eval(responses, save_file)

print(f"Chat Safe Acc: {c_safe_acc}, Chat Unsafe Acc: {c_unsafe_acc}, Chat Total Acc: {c_total_acc}")
print(f"Embodied Safe Acc: {e_safe_acc}, Embodied Unsafe Acc: {e_unsafe_acc}, Embodied Total Acc: {e_total_acc}")