
import pandas as pd
from tqdm import tqdm  # 进度条工具
from Models.mPLUG import mPLUG
from utils.ECE import *
from utils.utils import *
import json
import argparse
import os
from utils.conversation import conv_templates, DEFAULT_IMAGE_TOKEN
import os, json
import pandas as pd
from tqdm import tqdm
from utils.conversation import conv_templates, DEFAULT_IMAGE_TOKEN
from collections import Counter

def _extract_gt_answer(row, dataset_name):
    """Return the raw ground-truth answer for ``row``, or ``None`` if unusable.

    A message is printed whenever ``None`` is returned, either because the
    answer field is empty or because no extraction rule exists for
    ``dataset_name``.
    """
    if dataset_name == "VQAv2":
        gt_answer = row.get('multiple_choice_answer', '')
        if not gt_answer:
            print("Ground truth is none for VQAv2")
            return None
        return gt_answer

    if dataset_name in ("OKVQA", "VizWiz", "TextVQA"):
        gt_answers = row.get('answers', '')
        # len() instead of truthiness: the column presumably holds an
        # array-like of answer strings (TODO confirm), whose truth value
        # may be ambiguous.
        if gt_answers is None or len(gt_answers) == 0:
            print("Ground truth is none for", dataset_name)
            return None
        # Use the most frequent annotator answer as the single reference.
        return Counter(gt_answers).most_common(1)[0][0]

    if dataset_name == "MMBench":
        gt_answer = row.get('answer', '')
        if not gt_answer:
            print("Ground truth is none for MMBench")
            return None
        return gt_answer

    print("Ground truth extraction rule not defined for", dataset_name)
    return None


def _summarize_scores(results, key):
    """Return ``(total, mean)`` of the non-failed ``key`` scores in ``results``.

    Negative values are the ``-1`` failure sentinel and are excluded from
    both the sum and the denominator. The mean is NaN when nothing is valid.
    """
    valid = [r[key] for r in results if r[key] >= 0]
    total = sum(valid)
    mean = total / len(valid) if valid else float('nan')
    return total, mean


def eval_model(data, model, args, add_image=True):
    """Score ``model`` on every row of ``data`` and persist the results.

    For each row a prompt is built from the ``question`` column, the model is
    queried with the row's image bytes, and NLL / Brier / MacroCE scores are
    computed against the ground-truth answer. Per-sample results are written
    to ``answer.json`` under ``./result/nll_6/`` and one summary row is
    appended to a per-dataset CSV report.

    Args:
        data: DataFrame with a ``question`` column, an ``image`` column whose
            values expose a ``"bytes"`` entry, and the dataset-specific
            answer column.
        model: Object providing ``generate_prompt``, ``get_answer``,
            ``get_word_nll_and_brier`` and a ``name`` attribute.
        args: Namespace providing ``conv_mode``, ``dataset_name`` and
            ``split``.
        add_image: Unused; kept for interface compatibility.

    Notes:
        Samples whose scoring raised are stored with ``-1`` scores and are
        excluded from the reported sums and means. ``sample_count`` still
        counts every stored sample, so a mean may differ from
        ``sum / sample_count`` when failures occurred.
    """
    # Bound before the loop so it is defined even when ``data`` is empty or
    # every row is skipped before ground-truth extraction.
    dataset_name = args.dataset_name
    answer_hint = ' Please use a single-word or phrase to answer.'
    all_results = []

    for _, row in tqdm(data.iterrows(), desc="Evaluating", total=len(data)):
        qs = row.get('question', '')
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs + answer_hint

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        image_bytes = row['image']["bytes"]
        model.generate_prompt(prompt)

        response = model.get_answer(image_bytes)
        if response is None:
            continue
        try:
            # The second element of the response is what the scorer consumes.
            input_data = response[1]
        except (TypeError, IndexError):
            continue

        gt_answer = _extract_gt_answer(row, dataset_name)
        if gt_answer is None:
            continue
        gt_answer = gt_answer.strip().lower()

        try:
            nll_score, brier_score, MacroCE_score, _, _ = model.get_word_nll_and_brier(gt_answer, input_data)
        except Exception as e:
            # Best effort: keep the sample but mark its scores as failed.
            print(f"[WARN] NLL/Brier computation failed: {gt_answer} | {e}")
            nll_score, brier_score, MacroCE_score = -1, -1, -1

        all_results.append({
            'question': qs,
            'gt_answer': gt_answer,
            'nll': nll_score,
            'brier': brier_score,
            'MacroCE': MacroCE_score,
        })

    save_dir = f'./result/nll_6/{dataset_name}-{args.split}/{model.name}/'
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, 'answer.json'), 'w', encoding='utf-8') as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)

    if not all_results:
        print(f"[WARN] No samples were evaluated for {dataset_name}; reported averages are NaN.")

    sum_nll, avg_nll = _summarize_scores(all_results, 'nll')
    sum_brier, avg_brier = _summarize_scores(all_results, 'brier')
    sum_MacroCE, avg_MacroCE = _summarize_scores(all_results, 'MacroCE')

    report = {
        'model': model.name,
        'dataset': dataset_name,
        'split': args.split,
        'nll': avg_nll,
        'brier': avg_brier,
        'MacroCE': avg_MacroCE,
        'sum_nll': sum_nll,
        'sum_brier': sum_brier,
        'sum_MacroCE': sum_MacroCE,
        'sample_count': len(all_results),
    }

    report_path_map = {
        "VQAv2": "./result/nll_6/vqa_nll.csv",
        "VizWiz": "./result/nll_6/viz_nll.csv",
        "OKVQA": "./result/nll_6/okvqa_nll.csv",
        "TextVQA": "./result/nll_6/textvqa_nll.csv",
        "MMBench": "./result/nll_6/mmbench_nll.csv",
    }
    report_path = report_path_map.get(dataset_name, f"./{dataset_name}_nll_report.csv")

    # Append to the shared report; write the header only when creating it.
    pd.DataFrame([report]).to_csv(
        report_path,
        mode='a',
        header=not os.path.exists(report_path),
        index=False
    )
    print(f"✅ Report saved to: {report_path}")

def main(args):
    """Load the requested split, sample it, and evaluate the mPLUG model.

    Reads every ``*.parquet`` file in ``args.data_path`` whose name starts
    with ``args.split``, draws up to ``args.num_test_samples`` rows with a
    fixed seed, and passes them to ``eval_model``.

    Raises:
        ValueError: If no parquet file for the requested split is found.
    """
    # os.listdir returns entries in arbitrary order; sort so the concatenated
    # row order, and therefore the seeded sample below, is reproducible.
    selected_files = sorted(
        os.path.join(args.data_path, f) for f in os.listdir(args.data_path)
        if f.endswith('.parquet') and f.startswith(args.split)
    )
    if not selected_files:
        raise ValueError(
            f"No '{args.split}*.parquet' files found in {args.data_path!r}"
        )

    data = pd.concat([pd.read_parquet(file) for file in selected_files], ignore_index=True)

    # DataFrame.sample raises when n exceeds the number of rows, so clamp it.
    num_samples = min(args.num_test_samples, len(data))
    if num_samples < args.num_test_samples:
        print(f"[WARN] Requested {args.num_test_samples} samples but only {len(data)} rows are available; using {num_samples}.")
    data = data.sample(n=num_samples, random_state=42).reset_index(drop=True)

    # Build the model only after the data loaded, so path errors fail fast.
    model = mPLUG(args)
    eval_model(data, model, args, add_image=True)
    
if __name__ == "__main__":
    # Command-line entry point: parse the evaluation options and run main().
    # Options not read in this file (model_path, model-base, device,
    # temperature, top_p, num_beams, max_new_tokens) are presumably consumed
    # by mPLUG(args) -- TODO confirm.
    parser = argparse.ArgumentParser(description="NLL and Brier evaluation")
    parser.add_argument('--model_path', type=str, help="Path to the pre-trained or distilled model file. Specify the location of the model.")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument('--device', type=str, default="cuda:0")
    # Required: main() lists this directory to find the parquet shards.
    parser.add_argument('--data_path', type=str, required=True, help="Path to the dataset.")
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--num_test_samples', type=int, default=1)
    # Required: selects the ground-truth extraction rule and the output paths.
    parser.add_argument('--dataset_name', type=str, required=True, help="Dataset Name, serve as savedir")
    parser.add_argument('--split', type=str, default="val", help="The Part of Dataset")
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--conv_mode", default="llava_llama_2", help="The system prompt of model")

    args = parser.parse_args()

    main(args)
