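# NOTE: everything from here through `run()` below is an earlier, fully
# commented-out version of this script, kept for reference only. The active
# code starts at "eval_mme_minigpt4_simple.py" further down.
# # eval_mme_minigpt4.py
# # Evaluate MiniGPT-4 on the MME dataset (MME folder structure: each subtask folder contains images and matching .txt files)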

# import argparse
# import os
# import json
# import re
# from PIL import Image
# from tqdm import tqdm

# import torch
# from transformers import StoppingCriteriaList

# from minigpt4.common.config import Config
# from minigpt4.common.registry import registry
# from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub

# # load model components (ensure your PYTHONPATH includes MiniGPT-4)
# from minigpt4.datasets.builders import *
# from minigpt4.models import *
# from minigpt4.processors import *
# from minigpt4.runners import *
# from minigpt4.tasks import *


# def parse_args():
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--cfg-path", default="/root/autodl-tmp/MiniGPT-4/eval_configs/minigpt4_eval.yaml",
#                         help="path to config file")
#     parser.add_argument("--mme-path", type=str, default="/root/autodl-tmp/dataset/MME",
#                         help="Path to MME root directory (contains subtask folders)")
#     parser.add_argument("--output-file", default="minigpt4_mme_results.jsonl")
#     parser.add_argument("--gpu-id", type=int, default=0)
#     # parser.add_argument("--K", type=int, default=None, help="(optional) which layer to prune - injected to model config if present")
#     # parser.add_argument("--retain-ratio", type=float, default=None, help="(optional) retain ratio - injected to model config if present")
#     # parser.add_argument("--method", type=str, default=None, help="(optional) method - injected to model config if present")
#     parser.add_argument(
#         "--options",
#         nargs="+",
#         default=[],
#         help="override some settings in the used config"
#     )
#     return parser.parse_args()


# def load_json_or_jsonl(path):
#     with open(path, "r", encoding="utf-8") as f:
#         first_line = f.readline().strip()
#         f.seek(0)
#         if first_line.startswith("["):
#             return json.load(f)
#         else:
#             return [json.loads(line.strip()) for line in f if line.strip()]


# def parse_txt_file(txt_path):
#     """
#     Parse txt to two (question, answer) pairs.
#     Same logic as the instructblip script: split at last '.' per line.
#     """
#     questions, answers = [], []
#     with open(txt_path, 'r', encoding='utf-8') as f:
#         lines = f.readlines()
#         for line in lines:
#             line = line.strip()
#             if not line:
#                 continue
#             last_dot = line.rfind('.')
#             if last_dot != -1 and last_dot + 1 < len(line):
#                 question = line[:last_dot + 1].strip()
#                 answer = line[last_dot + 1:].strip()
#             else:
#                 # fallback: split by first tab or ' ' if not conforming
#                 parts = re.split(r'\t+', line, maxsplit=1)
#                 if len(parts) == 2:
#                     question, answer = parts[0].strip(), parts[1].strip()
#                 else:
#                     # last resort: treat whole line as question, empty answer
#                     question, answer = line, ""
#             questions.append(question)
#             answers.append(answer)
#     return questions, answers


# def clean_and_append_instruction(q):
#     # remove "Please answer yes or no..." and similar suffixes, case-insensitive
#     q = re.sub(r'Please answer yes or no.*$', '', q, flags=re.IGNORECASE).strip()
#     q = q.rstrip('.')
#     # append instruction
#     q = q + ". Answer the question using yes or no."
#     return q


# def run():
#     args = parse_args()
#     cfg = Config(args)

#     conv_dict = {
#         'pretrain_vicuna0': CONV_VISION_Vicuna0,
#         'pretrain_llama2': CONV_VISION_LLama2,
#     }

#     model_config = cfg.model_cfg
#     # device assignment: using cuda:<gpu-id>
#     device_str = f'cuda:{args.gpu_id}'
#     model_config.device_8bit = args.gpu_id

#     model_cls = registry.get_model_class(model_config.arch)
#     model = model_cls.from_config(model_config).to(device_str)

#     # if optional config injection requested (keep safe checks)
#     try:
#         if args.K is not None and hasattr(model, "llm_model") and hasattr(model.llm_model, "config"):
#             setattr(model.llm_model.config, "K", args.K)
#         if args.retain_ratio is not None and hasattr(model, "llm_model") and hasattr(model.llm_model, "config"):
#             setattr(model.llm_model.config, "retain_ratio", args.retain_ratio)
#         if args.method is not None and hasattr(model, "llm_model") and hasattr(model.llm_model, "config"):
#             setattr(model.llm_model.config, "method", args.method)
#     except Exception as e:
#         print(f"[WARN] failed to inject optional config params: {e}")

#     conv_template = conv_dict.get(model_config.model_type, CONV_VISION_Vicuna0)

#     # get visual processor (use eval/val processor if available)
#     vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
#     vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

#     # stopping criteria - keep as in original
#     stop_words_ids = [[835], [2277, 29937]]
#     stop_words_ids = [torch.tensor(ids).to(device_str) for ids in stop_words_ids]
#     stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

#     chat = Chat(model, vis_processor, device=device_str, stopping_criteria=stopping_criteria)

#     mme_root = args.mme_path
#     subtasks = sorted([d for d in os.listdir(mme_root) if os.path.isdir(os.path.join(mme_root, d))])
#     all_results = []

#     for subtask in subtasks:
#         subtask_path = os.path.join(mme_root, subtask)
#         files = sorted(os.listdir(subtask_path))
#         subtask_results = []

#         for file in files:
#             if not (file.lower().endswith(".jpg") or file.lower().endswith(".png")):
#                 continue

#             base_name = os.path.splitext(file)[0]
#             image_path = os.path.join(subtask_path, file)
#             txt_path = os.path.join(subtask_path, base_name + ".txt")

#             if not os.path.exists(txt_path):
#                 print(f"Warning: Missing .txt for {file} in {subtask}. Skipping.")
#                 continue

#             try:
#                 questions, gts = parse_txt_file(txt_path)
#                 # expect two questions/answers
#                 if not (len(questions) >= 2 and len(gts) >= 2):
#                     print(f"Failed to parse {txt_path}: expecting at least 2 lines, got {len(questions)}")
#                     continue
#             except Exception as e:
#                 print(f"Failed to parse {txt_path}: {e}")
#                 continue

#             # clean and append instructions like instructblip script
#             q1 = clean_and_append_instruction(questions[0])
#             q2 = clean_and_append_instruction(questions[1])

#             # load image and run two inferences
#             try:
#                 image = Image.open(image_path).convert("RGB")
#             except Exception as e:
#                 print(f"Failed to open image {image_path}: {e}")
#                 continue

#             conv = conv_template.copy()
#             chat_state = conv
#             img_list = []

#             # We will run two separate queries per image
#             # try:
#                 # Upload and encode image once
#             chat.upload_img(image, chat_state, img_list)
#             chat.encode_img(img_list)

#                 # Q1
#             chat.ask(q1, chat_state)
#             pred1 = chat.answer(
#                     conv=chat_state,
#                     img_list=img_list,
#                     num_beams=1,
#                     temperature=0.001,
#                     max_new_tokens=5,
#                     max_length=2000
#                 )
#             pred1 = pred1[0] if isinstance(pred1, (list, tuple)) else pred1

#                 # Q2
#             chat.ask(q2, chat_state)
#             pred2 = chat.answer(
#                     conv=chat_state,
#                     img_list=img_list,
#                     num_beams=1,
#                     temperature=0.001,
#                     max_new_tokens=5,
#                     max_length=2000
#                 )
#             pred2 = pred2[0] if isinstance(pred2, (list, tuple)) else pred2

#             # except Exception as e:
#             #     print(f"[ERROR] inference error for {image_path}: {e}")
#             #     pred1, pred2 = "ERROR", "ERROR"

#             # ground truths may contain extra info; take first two
#             label_1 = gts[0] if len(gts) >= 1 else ""
#             label_2 = gts[1] if len(gts) >= 2 else ""

#             subtask_results.append({
#                 "image": file,
#                 "label_1": label_1,
#                 "pred_1": pred1,
#                 "label_2": label_2,
#                 "pred_2": pred2
#             })

#             qid = f"{subtask}/{base_name}"
#             print(f"[{qid}] Q1: {q1} | GT1: {label_1} | Pred1: {pred1}")
#             print(f"[{qid}] Q2: {q2} | GT2: {label_2} | Pred2: {pred2}")

#         all_results.append({subtask: subtask_results})

#     # Save results as jsonl: one line per subtask (dict)
#     with open(args.output_file, 'w', encoding='utf-8') as f:
#         for r in all_results:
#             f.write(json.dumps(r, ensure_ascii=False) + "\n")

#     print(f"Saved results to {args.output_file}")


# if __name__ == "__main__":
#     run()

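# eval_mme_minigpt4_simple.py
# Minimal MiniGPT-4 evaluation on the MME benchmark, kept as simple as the
# companion InstructBLIP script. Each MME subtask folder holds images with a
# matching .txt file of question/answer lines; predictions are written as
# JSONL, one line per subtask.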

import argparse
import json
import os
import re
from PIL import Image
from tqdm import tqdm

import torch
from transformers import StoppingCriteriaList

from minigpt4.common.config import Config
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub

# Load model components (ensure your PYTHONPATH includes MiniGPT-4)
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *


def load_image(image_path):
    """Open ``image_path`` with PIL and return it as a 3-channel RGB image."""
    picture = Image.open(image_path)
    rgb_picture = picture.convert("RGB")
    return rgb_picture


def run_single_inference(chat, conv_template, question, image_path, device_str):
    """Ask MiniGPT-4 a single question about a single image.

    Every call starts from its own ``conv_template.copy()``, so (assuming
    ``copy()`` yields an independent conversation) no dialogue history is
    carried over between questions.

    Args:
        chat: MiniGPT-4 ``Chat`` wrapper providing upload_img/encode_img/ask/answer.
        conv_template: conversation template; copied before use.
        question: prompt text handed to ``chat.ask``.
        image_path: path of the image file to load.
        device_str: not used by this function; kept for call-site compatibility.

    Returns:
        The first element of ``chat.answer``'s result when that result is a
        list/tuple, otherwise the result unchanged.
    """
    picture = load_image(image_path)
    state = conv_template.copy()
    embeds = []

    chat.upload_img(picture, state, embeds)
    chat.encode_img(embeds)
    chat.ask(question, state)

    # Near-zero temperature with a single beam; only one new token is
    # generated, which presumably suffices for a "Yes"/"No" reply -- verify.
    decode_settings = dict(
        num_beams=1,
        temperature=0.000001,
        max_new_tokens=1,
        max_length=2000,
    )
    reply = chat.answer(conv=state, img_list=embeds, **decode_settings)

    if isinstance(reply, (list, tuple)):
        return reply[0]
    return reply


def parse_txt_file(txt_path):
    """Parse an MME annotation file into parallel question/answer lists.

    Each non-blank line holds a question followed by its ground-truth answer,
    e.g. ``"Is there a dog? Please answer yes or no.\\tYes"``. A line is split
    at its first tab when it contains one; otherwise it is split just after
    its last '.' (the original heuristic). Lines that yield no non-empty
    answer are skipped.

    Args:
        txt_path: path to the UTF-8 encoded ``.txt`` annotation file.

    Returns:
        ``(questions, answers)``: two lists of stripped strings of equal length.
    """
    questions, answers = [], []
    with open(txt_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            # Strip first so the result does not depend on whether the line
            # ends with a newline (the last line of a file often does not).
            line = raw_line.strip()
            if not line:
                continue
            if "\t" in line:
                # Tab-separated "question<TAB>answer": robust to dots inside
                # the question ("Mr. Smith") or the answer ("3.5").
                question, _, answer = line.partition("\t")
            else:
                last_dot = line.rfind(".")
                if last_dot == -1:
                    continue
                question, answer = line[:last_dot + 1], line[last_dot + 1:]
            question, answer = question.strip(), answer.strip()
            if not answer:
                # A question without a ground truth cannot be scored.
                continue
            questions.append(question)
            answers.append(answer)
    return questions, answers


def _clean_and_append_instruction(question):
    """Drop MME's "Please answer yes or no..." suffix and append our own instruction.

    Same cleaning as the (commented-out) earlier version of this script,
    which documents it as matching the InstructBLIP evaluation script.
    """
    question = re.sub(r'Please answer yes or no.*$', '', question, flags=re.IGNORECASE).strip()
    question = question.rstrip('.')
    # NOTE: a question ending in '?' becomes "...?. Answer ..."; kept as-is so
    # prompts stay identical to the InstructBLIP run being compared against.
    return question + ". Answer the question using yes or no."


def _inject_pruning_options(model, args):
    """Best-effort copy of the optional K / retain_ratio / method args onto ``model.llama_model.config``."""
    for name in ("K", "retain_ratio", "method"):
        value = getattr(args, name, None)
        if value is None:
            continue
        try:
            setattr(model.llama_model.config, name, value)
        except Exception as e:  # model may not expose llama_model.config
            print(f"[WARN] could not set llama_model.config.{name}={value!r}: {e}")


def _build_chat(args, cfg):
    """Build the model, visual processor and ``Chat`` wrapper described by ``cfg``.

    Returns:
        ``(chat, conv_template, device_str)``.
    """
    conv_dict = {
        'pretrain_vicuna0': CONV_VISION_Vicuna0,
        'pretrain_llama2': CONV_VISION_LLama2,
    }

    model_config = cfg.model_cfg
    # Falls back to CPU when CUDA is unavailable.
    device_str = f'cuda:{args.gpu_id}' if torch.cuda.is_available() else "cpu"
    model_config.device_8bit = args.gpu_id
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to(device_str)

    _inject_pruning_options(model, args)

    conv_template = conv_dict.get(model_config.model_type, CONV_VISION_Vicuna0)

    # Visual processor taken from the config's `train` section (kept as in the original script).
    vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
    vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

    # Stop-token ids kept from the original script (presumably the '###'
    # conversation separator of the Vicuna tokenizer -- verify).
    stop_words_ids = [torch.tensor(ids).to(device_str) for ids in ([835], [2277, 29937])]
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    chat = Chat(model, vis_processor, device=device_str, stopping_criteria=stopping_criteria)
    return chat, conv_template, device_str


def _evaluate_subtask(chat, conv_template, subtask_path, device_str):
    """Answer both questions for every annotated image in one MME subtask folder.

    Returns:
        A list of ``{"label_1", "pred_1", "label_2", "pred_2"}`` dicts.
    """
    subtask_results = []
    for file in sorted(os.listdir(subtask_path)):
        # Case-insensitive so files like "x.JPG" are not silently skipped.
        if not file.lower().endswith((".jpg", ".jpeg", ".png")):
            continue

        base_name = os.path.splitext(file)[0]
        image_path = os.path.join(subtask_path, file)
        txt_path = os.path.join(subtask_path, base_name + ".txt")

        if not os.path.exists(txt_path):
            print(f"Warning: Missing .txt for {file}")
            continue

        try:
            questions, gts = parse_txt_file(txt_path)
        except (OSError, ValueError) as e:  # ValueError covers UnicodeDecodeError
            print(f"Failed to parse {txt_path}: {e}")
            continue
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(questions) != 2 or len(gts) != 2:
            print(f"Failed to parse {txt_path}: expected 2 question/answer lines, got {len(questions)}")
            continue

        # Feed the real question (cleaned) to the model; each question is asked
        # in a fresh conversation by run_single_inference.
        pred_1, pred_2 = (
            run_single_inference(chat, conv_template, _clean_and_append_instruction(q), image_path, device_str)
            for q in questions
        )

        subtask_results.append({
            "label_1": gts[0],
            "pred_1": pred_1,
            "label_2": gts[1],
            "pred_2": pred_2
        })
    return subtask_results


def run_minigpt4_mme(args):
    """Evaluate MiniGPT-4 on every MME subtask and write predictions as JSONL.

    Args:
        args: namespace with ``cfg_path``, ``mme_path``, ``output_file``,
            ``gpu_id`` and optionally ``options``, ``K``, ``retain_ratio``,
            ``method`` (the last three are copied onto the LLM config when set).

    Output: ``args.output_file`` gets one JSON object per subtask, of the form
    ``{subtask_name: [per-image records]}``.
    """
    if not hasattr(args, "options"):
        args.options = []   # MiniGPT-4's Config expects args.options
    cfg = Config(args)

    # Create the output directory up front so a bad path fails before the
    # (long) inference run rather than after it.
    out_dir = os.path.dirname(args.output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    chat, conv_template, device_str = _build_chat(args, cfg)

    results = []
    subtasks = sorted(d for d in os.listdir(args.mme_path) if os.path.isdir(os.path.join(args.mme_path, d)))
    for subtask in subtasks:
        subtask_path = os.path.join(args.mme_path, subtask)
        subtask_results = _evaluate_subtask(chat, conv_template, subtask_path, device_str)
        print(f"[{subtask}] evaluated {len(subtask_results)} images")
        results.append({subtask: subtask_results})

    with open(args.output_file, "w", encoding="utf-8") as f:
        for r in results:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate MiniGPT-4 on the MME benchmark.")
    parser.add_argument("--cfg-path", default="/root/autodl-tmp/MiniGPT-4/eval_configs/minigpt4_eval.yaml", help="path to config file")
    parser.add_argument("--mme-path", type=str, default="/root/autodl-tmp/dataset/MME", help="Path to MME data")
    parser.add_argument("--output-file", type=str, default="minigpt4_mme_results.jsonl")
    parser.add_argument("--gpu-id", type=int, default=0)
    parser.add_argument("--K", type=int, default=13, help="which layer to prune")
    # Hyphenated alias matches the other flags; the original underscore form still works.
    parser.add_argument("--retain_ratio", "--retain-ratio", dest="retain_ratio", type=float, default=1.0, help="retaining ratio")
    parser.add_argument("--method", type=str, default="our", help="method name copied onto llama_model.config.method")
    # Forwarded to MiniGPT-4's Config as args.options (config overrides).
    parser.add_argument("--options", nargs="+", default=[], help="override some settings in the used config")
    args = parser.parse_args()

    run_minigpt4_mme(args)
