import argparse
import json
import os
import sys
import openai
from openai import OpenAI
sys.path.append(('../'))
sys.path.append(('../../'))
from method.eval import test_each_mss, eval_mss_bench, eval_moss_bench
from method.load_models import load_model, format_conversation, call_model_wrapper

# Module-level OpenAI setup, executed at import time.
# os.getenv returns None when OPENAI_API_KEY is not set, so the key may be None here.
openai.api_key = os.getenv("OPENAI_API_KEY")
# NOTE(review): `client` is not referenced anywhere in the visible part of this
# file; presumably the OpenAI() constructor also reads OPENAI_API_KEY from the
# environment itself and may raise if it is missing — confirm against the
# installed openai SDK version before relying on this.
client = OpenAI()


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the benchmark evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The populated ``argparse.Namespace``; one attribute per option
        declared below.

    Raises:
        SystemExit: Raised by argparse when a required option
            (``--model_path``, ``--output_name``) is missing or a value
            cannot be converted to the declared type.
    """
    parser = argparse.ArgumentParser(
        description="Contrastive Safety-Aware Decoding for Multimodal Models"
    )

    # Model selection and output naming.
    parser.add_argument("--model_type", type=str, default="llava", help="Model type: llava, llava-next, etc.")
    parser.add_argument("--model_path", type=str, required=True, help="Path or HF hub ID for model")
    parser.add_argument("--output_name", type=str, required=True, help="Saved name for output files")

    # Decoding hyper-parameters.
    parser.add_argument("--alpha", type=float, default=0.3, help="Contrastive scaling factor")
    parser.add_argument("--max_steps", type=int, default=5, help="Max guided decoding steps")
    parser.add_argument("--top_k", type=int, default=20, help="Top-k tokens to consider")
    parser.add_argument("--lambda_supp", type=float, default=1.0, help="Suppression strength for safe verdicts")
    parser.add_argument("--lambda_boost", type=float, default=1.0, help="Boost strength for unsafe verdicts")
    # The help text here used to be a copy of the --top_k help string.
    parser.add_argument(
        "--total_max_new_tokens",
        type=int,
        default=256,
        help="Maximum total number of new tokens to generate",
    )

    # Benchmark locations. The script only runs a benchmark when both its
    # data root and its output directory are supplied.
    parser.add_argument(
        "--mss_data_root",
        type=str,
        default=None,
        help="MSSBench data root (MSSBench runs only if --mss_output_dir is also set)",
    )
    parser.add_argument(
        "--mss_output_dir",
        type=str,
        default=None,
        help="MSSBench output directory (MSSBench runs only if --mss_data_root is also set)",
    )
    parser.add_argument(
        "--moss_data_root",
        type=str,
        default=None,
        help="MOSSBench data root (MOSSBench runs only if --moss_output_dir is also set)",
    )
    parser.add_argument(
        "--moss_output_dir",
        type=str,
        default=None,
        help="MOSSBench output directory (MOSSBench runs only if --moss_data_root is also set)",
    )
    parser.add_argument("--moss_data_list", nargs='+', help="Specify the data samples to be run")
    parser.add_argument("--moss_data_offset", type=int, default=0)
    # BooleanOptionalAction also registers the --no-moss_inference / --no-moss_eval forms.
    parser.add_argument("--moss_inference", default=True, action=argparse.BooleanOptionalAction)
    parser.add_argument("--moss_eval", default=True, action=argparse.BooleanOptionalAction)
    return parser.parse_args(argv)



def main():
    """Script entry point: load the model and run the configured benchmarks.

    Parses the command line with ``parse_args``, loads the model via
    ``load_model(args)`` and then:

    * runs ``eval_mss_bench`` when both ``--mss_data_root`` and
      ``--mss_output_dir`` were given, printing the six accuracy values it
      returns (safe / unsafe / total for the chat and embodied splits);
    * runs ``eval_moss_bench`` when both ``--moss_data_root`` and
      ``--moss_output_dir`` were given (its return value is not used).

    A benchmark whose root/output pair is incomplete is skipped.
    """
    args = parse_args()
    # The model is loaded before the benchmark checks, matching the
    # original execution order.
    model, processor, tokenizer = load_model(args)

    if args.mss_data_root is not None and args.mss_output_dir is not None:
        print("Running MSSBench evaluation...")
        c_safe_acc, c_unsafe_acc, c_total_acc, e_safe_acc, e_unsafe_acc, e_total_acc = eval_mss_bench(
            args, model, processor, tokenizer
        )
        print(f"MSSBench Chat: safe={c_safe_acc}, unsafe={c_unsafe_acc}, total={c_total_acc}")
        print(f"MSSBench Embodied: safe={e_safe_acc}, unsafe={e_unsafe_acc}, total={e_total_acc}")

    if args.moss_data_root is not None and args.moss_output_dir is not None:
        print("Running MOSSBench evaluation...")
        eval_moss_bench(args, model, processor, tokenizer)


if __name__ == "__main__":
    # Keep the guard trivial so the workflow lives in main() and no
    # script state leaks into module globals.
    main()

