import argparse
import os
from tqdm import tqdm
import pandas as pd
from evaluator.NLL.VQAv2 import eval_model
import pandas as pd
from utils.ECE import *
from utils.utils import *
import json
from main.Ask import get_model_class
from utils.prompt import *


def _load_data(args):
    """Read the requested dataset split into a single DataFrame.

    For ``MMBench`` the ``args.data_path`` is handed directly to
    ``pd.read_parquet``. For every other dataset ``args.data_path`` is a
    directory, and all ``*.parquet`` files whose names start with
    ``args.split`` are concatenated with a fresh index.

    Raises:
        ValueError: if the directory contains no matching parquet file.
    """
    if args.dataset_name in ['MMBench']:
        return pd.read_parquet(args.data_path)

    # os.listdir order is arbitrary and filesystem-dependent; sort so the
    # concatenated row order -- and therefore which rows the seeded sample in
    # main() selects -- is reproducible across machines.
    selected_files = sorted(
        os.path.join(args.data_path, f) for f in os.listdir(args.data_path)
        if f.endswith('.parquet') and f.startswith(args.split)
    )
    if not selected_files:
        # pd.concat([]) would fail with an opaque "No objects to concatenate".
        raise ValueError(
            f"No parquet files starting with {args.split!r} found in {args.data_path!r}"
        )
    return pd.concat([pd.read_parquet(file) for file in selected_files], ignore_index=True)


def main(args):
    """Evaluate a model on a seeded random subset of the dataset.

    Loads the data, draws up to ``args.num_test_samples`` rows with a fixed
    seed (``random_state=42``), builds the model class selected by
    ``args.model_path``, and passes everything to ``eval_model`` with
    ``add_image=True``.

    Args:
        args: parsed command-line namespace (see the argument parser below).
    """
    # Load data before the model so a bad path/split fails fast instead of
    # after model construction.
    data = _load_data(args)

    # DataFrame.sample raises ValueError when n exceeds the number of rows;
    # fall back to using every row (still shuffled with the same seed).
    n_samples = min(args.num_test_samples, len(data))
    if n_samples < args.num_test_samples:
        print(f"Requested {args.num_test_samples} samples but only {len(data)} available; using {n_samples}.")
    data = data.sample(n=n_samples, random_state=42).reset_index(drop=True)

    ModelClass = get_model_class(args.model_path)
    model = ModelClass(args)
    eval_model(data, model, args, add_image=True)
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="NLL and Brier evaluation")
    parser.add_argument('--model_path', type=str,help="Path to the pre-trained or distilled model file. Specify the location of the model.")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--data_path', type=str, help="Path to the dataset.")
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--num_test_samples', type=int, default=1000,help="Number of test samples.")
    parser.add_argument('--dataset_name', type=str, help="Dataset Name, serve as savedir")
    parser.add_argument('--split', type=str, default="val",help="The Part of Dataset")
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--conv_mode", default="llava_llama_2",help="The system prompt of model")



    args = parser.parse_args()

    main(args)
