import os

# Thread-count limits are exported before the numeric libraries below are imported.
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "4"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"

import argparse
import gc
import json
import math
import pickle
import random 
import re 
import string
import time
from typing import List, Dict, Iterable

import torch
from peft import PeftModel
from PIL import Image
from qwen_vl_utils import process_vision_info
from torchvision.ops.boxes import batched_nms, box_iou
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import Qwen2VLForConditionalGeneration
from transformers import BitsAndBytesConfig

from hico_text_label import MAP_AO_TO_HOI, OBJ_IDX_TO_OBJ_NAME, ACT_IDX_TO_ACT_NAME, RARE_HOI_IDX, HICO_INTERACTIONS, TIME_AMBIGUITY_HOI_INDEX, HOI_TO_AO
# from script_hico_evaluation_descrip import merge_json_files
from newbench_question_func import newbench_interaction_question, mllm_instancef1_eval, mllm_macrof1_eval, generate_candidate_pairs, match_gtbox, newbench_detection_question, parse_detection_answer, newbench_pre_question_imgsize, parse_imgsize_answer


# Convert a 0-based integer to an Excel-style column label: 0 -> A, 25 -> Z, 26 -> AA, ...
def _index_to_label(i: int) -> str:
    """Return the Excel-style (bijective base-26) label for the 0-based index `i`.

    Raises:
        ValueError: if `i` is negative. Without this guard the conversion loop
            never reaches zero for i <= -2 and would run forever.
    """
    if i < 0:
        raise ValueError(f"index must be non-negative, got {i}")
    letters = []
    remaining = i + 1  # work 1-based so that 'A' acts as digit 1, not 0
    while remaining:
        remaining, offset = divmod(remaining - 1, 26)
        letters.append(chr(ord("A") + offset))
    # Digits were produced least-significant first.
    return "".join(reversed(letters))

# Convert a letter label to a 0-based integer: A -> 0, Z -> 25, AA -> 26, ...
def _label_to_index(label: str) -> int:
    """Return the 0-based index encoded by an Excel-style label (case-insensitive).

    The label is upper-cased and each character is weighted as ord(ch) - 64;
    characters are not validated. An empty label yields -1.
    """
    digits = [ord(ch) - 64 for ch in label.upper()]
    # Positional value in base 26, least-significant digit last in the label.
    value = sum(digit * 26 ** power for power, digit in enumerate(reversed(digits)))
    return value - 1

def _labels(count: int, start: str = "A") -> Iterable[str]:
    """Lazily yield `count` consecutive labels, beginning with the label `start`."""
    first = _label_to_index(start)
    yield from map(_index_to_label, range(first, first + count))

def label_choices(
    choices: List[str],
    start: str = "A",
    az_only: bool = True,   # when True, more than 26 choices is rejected
) -> Dict[str, str]:
    """
    Map a list of choices to labelled entries: {'A': choice0, 'B': choice1, ...}.

    - An empty (falsy) `choices` gives an empty dict.
    - With az_only=True (the default), passing more than 26 choices raises
      ValueError; with az_only=False the labels continue as AA, AB, ... after Z.
    - `start` selects the first label (e.g. begin at 'C').
    """
    if choices:
        too_many = len(choices) > 26
        if az_only and too_many:
            raise ValueError("选项超过 26 个，但 az_only=True。请关闭 az_only 或减少数量。")
        keys = _labels(len(choices), start=start)
        return {key: choice for key, choice in zip(keys, choices)}
    return {}

### resize the image for Qwen processing
def qwen_img_resize(
    height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 12845056
) -> tuple[int, int]:
    """
    Compute the (height, width) an image is rescaled to so that:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.

    Returns:
        (h_bar, w_bar): the target height and width, each at least 'factor'
        on the rounding and shrink paths.

    Raises:
        ValueError: if either dimension is not positive, or if the aspect
            ratio (long side / short side) exceeds 200.
    """
    MAX_RATIO = 200
    # Reject degenerate sizes explicitly instead of failing with ZeroDivisionError below.
    if height <= 0 or width <= 0:
        raise ValueError(f"height and width must be positive, got height={height}, width={width}")
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {aspect_ratio}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Too large: shrink both sides by the same ratio, rounding down to a multiple of
        # 'factor'. Clamp to 'factor' so a very thin image never collapses to a 0-sized side.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        # Too small: enlarge both sides by the same ratio, rounding up to a multiple of 'factor'.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar

def round_by_factor(number: int, factor: int) -> int:
    """Round 'number' to a nearby multiple of 'factor' (ties resolved by the built-in round)."""
    multiples = round(number / factor)
    return multiples * factor

def ceil_by_factor(number: int, factor: int) -> int:
    """Round 'number' up to the nearest multiple of 'factor' (multiples are returned unchanged)."""
    steps = math.ceil(number / factor)
    return steps * factor

def floor_by_factor(number: int, factor: int) -> int:
    """Round 'number' down to the nearest multiple of 'factor' (multiples are returned unchanged)."""
    steps = math.floor(number / factor)
    return steps * factor



def extract_assistant_response(text):
    """Pull the assistant's reply out of a decoded chat transcript.

    The transcript is stripped and split on newlines, then:
    - fewer than two lines: the input is returned untouched (not stripped);
    - the second-to-last line is 'assistant' (ignoring surrounding whitespace):
      the stripped last line is returned;
    - some line is exactly 'assistant': every following line is stripped and
      the pieces are joined with single spaces;
    - otherwise the stripped transcript is returned.
    """
    stripped = text.strip()
    parts = stripped.split("\n")
    if len(parts) < 2:
        return text

    *_, penultimate, final = parts
    if penultimate.strip() == 'assistant':
        return final.strip()

    try:
        # First line that is exactly the role marker.
        marker = parts.index('assistant')
    except ValueError:
        return stripped

    tail = [piece.strip() for piece in parts[marker + 1:]]
    return " ".join(tail).strip()

def qwen_chatbox(processor, model, batch_conversation_history, args):
    """Generate one reply per conversation in a batch and return them as a list of strings.

    Args:
        processor: object providing `apply_chat_template`, batch encoding via
            `__call__`, and `batch_decode`.
        model: generation model; inputs are moved to `model.device`.
        batch_conversation_history: iterable of conversations (one per sample).
        args: namespace; only `args.max_tokens` is read (as `max_new_tokens`).

    Returns:
        The decoded outputs after `extract_assistant_response` has been applied
        to each of them.
    """
    prompts = []
    images = []
    for conversation in batch_conversation_history:
        rendered = processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        prompts.append(rendered)
        image_part, _ = process_vision_info(conversation)
        images.append(image_part)

    # Encode the whole batch in a single processor call.
    batch = processor(
        text=prompts,
        images=images,
        return_tensors="pt",
        padding=True
    )

    model_inputs = {}
    for name, tensor in batch.items():
        model_inputs[name] = tensor.to(model.device)

    print("🔮 Generating...")
    output_ids = model.generate(**model_inputs, max_new_tokens=args.max_tokens)
    decoded = processor.batch_decode(output_ids, skip_special_tokens=True)
    responses = list(map(extract_assistant_response, decoded))

    print("\n📢 Prediction for file : ")
    print(responses)

    return responses

def qwen_chatbox_with_probs(processor, model, batch_conversation_history, args):
    """Generate replies for a batch and also return the probability of every generated token.

    Args:
        processor: object providing `apply_chat_template`, batch encoding via
            `__call__`, `batch_decode`, and a `tokenizer` with `decode`.
        model: generation model; inputs are moved to `model.device`.
        batch_conversation_history: iterable of conversations (one per sample).
        args: namespace; only `args.max_tokens` is read (as `max_new_tokens`).

    Returns:
        (decoded, gen_ids, token_probs):
        decoded is the batch-decoded text of the full output sequences (no
        assistant-reply extraction is applied here), gen_ids holds the ids after
        the prompt, [batch, gen_len], and token_probs holds the softmax
        probability of each of those ids under the per-step scores returned by
        `generate`, [batch, gen_len].

    Side effects:
        Prints each generated token of the first sample with its probability.
    """
    # ---- build prompts and vision inputs, one entry per conversation ----
    prompts, images = [], []
    for conversation in batch_conversation_history:
        prompts.append(
            processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        )
        image_part, _ = process_vision_info(conversation)
        images.append(image_part)

    batch = processor(
        text=prompts,
        images=images,
        return_tensors="pt",
        padding=True,
    )
    inputs = {name: tensor.to(model.device) for name, tensor in batch.items()}

    # ---- generate and keep the score tensor of every step ----
    gen_out = model.generate(
        **inputs,
        max_new_tokens=args.max_tokens,
        return_dict_in_generate=True,
        output_scores=True,
    )
    seqs = gen_out.sequences                 # [batch, prompt_len + gen_len]
    prompt_len = inputs["input_ids"].shape[-1]
    gen_ids = seqs[:, prompt_len:]           # [batch, gen_len]

    # ---- probability assigned to the token that was actually produced at each step ----
    per_step = []
    for step, step_scores in enumerate(gen_out.scores):   # each entry: [batch, vocab]
        distribution = torch.softmax(step_scores, dim=-1)
        chosen = gen_ids[:, step].unsqueeze(-1)           # [batch, 1]
        per_step.append(distribution.gather(-1, chosen).squeeze(-1))
    token_probs = torch.stack(per_step, dim=1)            # [batch, gen_len]

    # ---- decode, and show the per-token probabilities of the first sample ----
    decoded = processor.batch_decode(seqs, skip_special_tokens=True)
    first_tokens = [processor.tokenizer.decode([tid]) for tid in gen_ids[0]]
    first_probs = token_probs[0].tolist()
    for token_text, prob in zip(first_tokens, first_probs):
        print(f"{token_text!r} : {prob:.4f}")
    return decoded, gen_ids, token_probs



def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def main(args):
    """Answer multiple-choice HOI questions with a Qwen-VL model and score the predictions.

    Workflow:
      1. Unless `args.hoi_pred_json_file` is given, load the model/processor
         (optionally with LoRA weights from `args.lora_dir`), walk
         `args.image_folder`, build one prompt per question of each image,
         generate answers in chunks of `args.batch_size`, and write periodic
         `<cnt>_hoi_eval.json` checkpoints plus `merged_hoi_eval.json` into
         `args.output`.
      2. Otherwise load the predictions from `args.hoi_pred_json_file`.
      3. Evaluate the predictions against the question file and write
         `<person_settings>_evaluation_results.txt` into `args.output`.

    Args:
        args: namespace with the attributes defined by this script's CLI.
            Attributes read here: image_folder, output, model, hf_pth,
            lora_dir, hoi_question_json_file, hoi_pred_json_file,
            prompt_box_type, detection_pth, two_stage, previous_preds_info,
            person_settings, batch_size, reasoning, num_hoi_cls
            (and max_tokens via qwen_chatbox).

    Raises:
        ValueError: if `args.model` names neither a Qwen2.5 nor a Qwen2 checkpoint.
    """
    image_folder = args.image_folder
    output_folder = args.output
    model_id = args.model
    two_stage = bool(args.two_stage)

    os.makedirs(output_folder, exist_ok=True)

    if args.hoi_pred_json_file is None:
        print(f"🚀 Loading model: {model_id}")

        # Snapshot hashes used to resolve known checkpoints inside a local Hugging Face cache.
        file_str = {
            'Qwen/Qwen2.5-VL-32B-Instruct': '60bc26b46ee6c2b83f1306042b3da1e46cb861d5',
            'Qwen/Qwen2.5-VL-7B-Instruct': '5b5eecc7efc2c3e86839993f2689bbbdf06bd8d4',
        }
        # Only rewrite to a local snapshot path when a cache root was supplied;
        # with hf_pth=None the plain model id is handed to from_pretrained.
        if model_id in file_str and args.hf_pth is not None:
            model_id = os.path.join(
                args.hf_pth,
                "hub/models--Qwen--" + model_id.split("/")[-1] + "/snapshots/" + file_str[model_id],
            )

        if 'Qwen2.5' in args.model:
            base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                # attn_implementation="flash_attention_2",
                trust_remote_code=True
            )
        elif 'Qwen2' in args.model:
            base_model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                # attn_implementation="flash_attention_2",
            )
        else:
            raise ValueError(
                f"Unsupported model '{args.model}': the name must contain 'Qwen2.5' or 'Qwen2'."
            )

        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        # Pad on the left for batched generation.
        processor.tokenizer.padding_side = "left"

        if args.lora_dir is not None:
            model = PeftModel.from_pretrained(base_model, args.lora_dir, device_map="auto")
            model.eval()
        else:
            model = base_model

    # NOTE(review): there is no sub-folder entry for prompt_box_type == 'all', so that
    # CLI choice fails with KeyError below — confirm the intended folder name.
    hoi_question_sub_folder = {"none": "no_box", "person": "person_box"}
    anno_file_name = "replaced_merged_hoi_results_pure_negpreds_v5.json" if args.prompt_box_type == 'none' else "merged_hoi_results_pure_negpreds_v5.json"

    hoi_question_json_file = os.path.join(args.hoi_question_json_file, hoi_question_sub_folder[args.prompt_box_type], anno_file_name)

    print("Loading the annotation from", hoi_question_json_file)

    with open(hoi_question_json_file, 'r') as f:
        hoi_det_question = json.load(f)

    detection_results = None
    if args.detection_pth is not None:
        print(f"🔎 Loading detection results from: {args.detection_pth}")
        # SECURITY: pickle.load executes arbitrary code; only point this at trusted files.
        with open(args.detection_pth, 'rb') as f:
            detection_results = pickle.load(f)

    # Resume support: previous_preds_info = (last processed file name, its checkpoint count).
    # The checkpoint's predictions are kept so that the merged output and the evaluation
    # cover both the earlier and the newly generated answers.
    if args.previous_preds_info is not None:
        previous_preds, cnt_num = args.previous_preds_info
        generate_flag = 0
        with open(os.path.join(output_folder, str(cnt_num) + "_hoi_eval.json"), 'r') as f:
            output_dict = json.load(f)
    else:
        generate_flag = 1
        previous_preds = None
        output_dict = {}

    cnt = 0
    all_ans_per_question = {'tp': [], 'fp': [], 'full_pred': 0, 'full_gt': 0, 'ood': []}  # tp, fp, full_pred, full_gt, ood
    f1_per_question = []
    macro_f1_dict = {}
    acc_top1 = 0
    acc_fullmatch = 0

    # Optional restriction of the run to the single-/multiple-person split.
    evaluation_files = None
    if args.person_settings == "single":
        setting_files_pth = "outputs/hico_split_output/single_person.json"
        with open(setting_files_pth, 'r') as f:
            evaluation_files = json.load(f)
    elif args.person_settings == "multiple":
        setting_files_pth = "outputs/hico_split_output/multiple_person.json"
        with open(setting_files_pth, 'r') as f:
            evaluation_files = json.load(f)

    if args.hoi_pred_json_file is None:
        # Checkpoint interval (in processed images).
        num_written = 200 if two_stage else 500

        for root, _, files in os.walk(image_folder):
            for file in sorted(files):
                if args.person_settings != "all" and file not in evaluation_files:
                    continue

                # When resuming, skip everything up to and including the last processed file.
                if previous_preds is not None and file == previous_preds:
                    generate_flag = 1
                    cnt = int(cnt_num)
                    print(f"⚠️ Skipping {file} as it already exists in previous predictions.")
                    continue
                if previous_preds is not None and generate_flag == 0:
                    print(f"⚠️ Skipping {file} as it already exists in previous predictions.")
                    continue

                # Files without questions are skipped before the image is opened.
                if file not in hoi_det_question:
                    continue

                start_time = time.time()
                image_path = os.path.join(root, file)
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert("RGB")

                # Pre-resize with qwen_img_resize and keep the x/y scale factors for the boxes.
                width, height = image.size
                resized_height, resized_width = qwen_img_resize(height, width)
                image = image.resize((resized_width, resized_height))
                box_scale_factor = torch.tensor([resized_width/width, resized_height/height, resized_width/width, resized_height/height])

                # An entry is either a dict of "QA_<k>" questions or a single bare question.
                hoi_det_questioni = hoi_det_question[file]
                if "QA_0" not in hoi_det_questioni:
                    hoi_det_questioni_qa = {"QA_0": hoi_det_questioni}
                else:
                    hoi_det_questioni_qa = hoi_det_questioni

                print("🔍 Processing file:", file)
                conversation_history_list = []
                question_list = []
                choices_list = []

                # Two-stage mode without pre-computed detections: ask the model to detect first.
                if two_stage and args.detection_pth is None:
                    prompt = newbench_detection_question(args.prompt_box_type)

                    conversation_history = [{
                        "role": "user",
                        "content": [
                            {"type": "image", "image": image},
                            {"type": "text", "text": prompt}
                        ],
                    }]

                    responses = qwen_chatbox(processor, model, [conversation_history], args)
                    det_result = parse_detection_answer(args.prompt_box_type, responses[0])
                    if len(det_result) == 0:
                        continue

                for qi in hoi_det_questioni_qa:
                    content_i = hoi_det_questioni_qa[qi]
                    hboxes_gt = None
                    oboxes_gt = None
                    if args.prompt_box_type != 'none':
                        hboxes_gt = content_i['boxes']["human"]
                    if args.prompt_box_type == 'all':
                        oboxes_gt = content_i['boxes']["object"]
                    gt_choices = content_i['gt_choices']
                    incorrect_choices = content_i['wrong_choices']

                    choices = gt_choices + incorrect_choices
                    random.shuffle(choices)

                    # Pad to at least four options.
                    while len(choices) < 4:
                        print("☹️ missing choices and add None as one")
                        choices.append('None')

                    prompt1 = None
                    if not two_stage:
                        # Single-stage: prompt with the ground-truth boxes.
                        prompt1 = newbench_interaction_question(choices, prompt_box_type=args.prompt_box_type, box_scale_factor=box_scale_factor, hbox=hboxes_gt, obox=oboxes_gt, box_resize=True, reasoning=args.reasoning)
                    else:
                        # Two-stage: prompt with the first detected pair that matches the GT boxes.
                        if args.detection_pth is not None:
                            det_result = detection_results[file][0]

                        try:
                            candidate_pairs, _ = generate_candidate_pairs(det_result, human_only=(args.prompt_box_type == 'person'), prompt_box_type=args.prompt_box_type)
                        except Exception:
                            print("❌ fail to extract detected bounding boxese with format issue")
                            continue

                        # Externally supplied detections are matched without the resize scale.
                        match_scale = None if args.detection_pth is not None else box_scale_factor
                        for ho_det_box_i in candidate_pairs:
                            sub_box = ho_det_box_i[0][1] if args.prompt_box_type != 'none' else None
                            obj_box = ho_det_box_i[1][1] if args.prompt_box_type == 'all' else None

                            gt_box_match, h_det_box_i, o_det_box_i = match_gtbox(sub_box, obj_box, hboxes_gt, oboxes_gt, args.prompt_box_type, box_scale_factor=match_scale)

                            # Equality (not identity/truthiness) is kept on purpose: the
                            # helper's return type is not visible from this file.
                            if gt_box_match == False:  # noqa: E712
                                continue

                            if args.detection_pth is not None:
                                prompt1 = newbench_interaction_question(choices, prompt_box_type=args.prompt_box_type, box_scale_factor=box_scale_factor, hbox=h_det_box_i, obox=o_det_box_i, box_resize=True, reasoning=args.reasoning)
                            else:
                                prompt1 = newbench_interaction_question(choices, prompt_box_type=args.prompt_box_type, hbox=h_det_box_i, obox=o_det_box_i, reasoning=args.reasoning)
                            break

                    if prompt1 is None:
                        print("No matched detection boxes, skip")
                        continue

                    conversation_history = [{
                        "role": "user",
                        "content": [
                            {"type": "image", "image": image},
                            {"type": "text", "text": prompt1}
                        ],
                    }]
                    conversation_history_list.append(conversation_history)
                    question_list.append(qi)
                    choices_list.append(label_choices(choices))

                # Run the prompts through the model in chunks of args.batch_size.
                response_list = []
                chunk_size = args.batch_size
                for i in range(0, len(conversation_history_list), chunk_size):
                    chunk = conversation_history_list[i:i + chunk_size]
                    response_chunk = qwen_chatbox(processor, model, chunk, args)
                    response_list.extend(response_chunk)

                # Map the answered letters back to the choice texts.
                output_dict[file] = {}
                for response_i, qli, chsi in zip(response_list, question_list, choices_list):
                    if args.reasoning != 'none':
                        answer_match = re.search(r"<answer>\s*([A-Za-z, ]+)\s*</answer>", response_i, re.IGNORECASE | re.DOTALL)
                        letters = answer_match.group(1).replace(" ", "") if answer_match else ""
                    else:
                        letters = response_i
                    ans_i = [x.strip().upper().rstrip('.') for x in letters.split(',') if x.strip()]
                    # Letters that are not valid option labels are dropped.
                    output_dict[file].setdefault(qli, []).extend(
                        chsi[letter] for letter in ans_i if letter in chsi
                    )

                del response_list
                gc.collect()
                torch.cuda.empty_cache()

                cnt += 1
                if cnt % num_written == 0 or cnt >= len(files):
                    out_file = os.path.join(output_folder, f"{cnt}_hoi_eval.json")
                    with open(out_file, "w") as f:
                        json.dump(output_dict, f, indent=2)
                    print(f"✅ Saved predictions to {out_file}")

                end_time = time.time()
                elapsed_time = end_time - start_time
                print(f"🕒 Time for {file}: {elapsed_time:.2f} seconds\n")

        out_file_name = "merged_hoi_eval.json"
        out_file = os.path.join(output_folder, out_file_name)
        with open(out_file, "w") as f:
            json.dump(output_dict, f, indent=2)

    else:
        with open(args.hoi_pred_json_file, 'r') as f:
            output_dict = json.load(f)

    ######### evaluation
    macro_f1_dict_hoicls = {i: {'tp': 0, 'fp': 0, 'gt': 0} for i in range(args.num_hoi_cls)}
    for file in hoi_det_question:
        if args.person_settings != "all" and file not in evaluation_files:
            continue

        hoi_det_questioni = hoi_det_question[file]
        # Images without predictions are scored as an empty answer.
        response_process_list = output_dict.get(file, {'QA_0': []})

        f1_per_question, macro_f1_dict, all_ans_per_question, acc_top1, acc_fullmatch = mllm_instancef1_eval(hoi_det_questioni, response_process_list, f1_per_question, macro_f1_dict, all_ans_per_question, file, acc_top1, acc_fullmatch)

    #### precision / recall over all answers (zero denominators are reported as 0.0)
    num_tp = len(all_ans_per_question['tp'])
    num_fp = len(all_ans_per_question['fp'])
    prec_all_ans_per_question = _safe_ratio(num_tp, num_tp + num_fp)
    recall_all_ans_per_question = _safe_ratio(num_tp, all_ans_per_question['full_gt'])
    print(f"All answers per question - Precision: {prec_all_ans_per_question:.4f}, Recall: {recall_all_ans_per_question:.4f}")

    instance_f1 = _safe_ratio(sum(f1_per_question), len(f1_per_question))
    print(f"Instance F1: {instance_f1:.4f}")

    macro_f1_dict_hoicls, macro_f1_list, prec_list, rec_list = mllm_macrof1_eval(macro_f1_dict, args.num_hoi_cls, macro_f1_dict_hoicls)
    macro_f1 = _safe_ratio(sum(list(macro_f1_list.values())), len(macro_f1_list))
    print(f"Macro F1: {macro_f1:.4f}")

    micro_F1 = _safe_ratio(
        2 * (prec_all_ans_per_question * recall_all_ans_per_question),
        prec_all_ans_per_question + recall_all_ans_per_question,
    )
    print(f"Micro F1: {micro_F1:.4f}")

    all_question = len(hoi_det_question) if args.person_settings == "all" else len([i for i in evaluation_files if i in hoi_det_question])
    top1_accuracy = _safe_ratio(acc_top1, all_question)
    fullmatch_accuracy = _safe_ratio(acc_fullmatch, all_question)
    print("Top 1 prediction accuracy: ", top1_accuracy)
    print("Full match prediction accuracy: ", fullmatch_accuracy)

    results_txt = os.path.join(output_folder, args.person_settings + "_evaluation_results.txt")
    with open(results_txt, "w") as ftxt:
        ftxt.write(f"All answers per question - Precision: {prec_all_ans_per_question:.4f}, Recall: {recall_all_ans_per_question:.4f}\n")
        ftxt.write(f"Instance F1: {instance_f1:.4f}\n")
        ftxt.write(f"Micro F1 Score: {micro_F1:.4f}\n")
        ftxt.write(f"Macro F1: {macro_f1:.4f}\n")
        # NOTE: a "Time-ambiguous Macro F1" line used to be written here from a variable
        # that is never computed in this function (NameError); it is omitted until
        # that metric is actually implemented.
        ftxt.write(f"Averaged HOI class precision: {_safe_ratio(sum(prec_list), len(prec_list)):.4f}\n")
        ftxt.write(f"Averaged HOI class recall: {_safe_ratio(sum(rec_list), len(rec_list)):.4f}\n")
        ftxt.write(f"Top 1 prediction accuracy: {top1_accuracy:.4f}\n")
        ftxt.write(f"Full match prediction accuracy: {fullmatch_accuracy:.4f}\n")

    print("Evaluation results saved to:", results_txt)

if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(description="Run Qwen2.5-VL-32B-Instruct for HOI detection with image input.")
    parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder of input images (walked recursively).")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-32B-Instruct", help="Model name or path.")
    # NOTE: --prompt is accepted but not read anywhere in this file.
    parser.add_argument("--prompt", type=str, default=None, help="Text prompt to use.")
    parser.add_argument("--max_tokens", type=int, default=2048, help="Max new tokens to generate.")
    parser.add_argument("--output", type=str, default="./output", help="Folder to save predictions.")
    parser.add_argument("--detection_pth", type=str, default=None, help="pth saved detected objects for HICO-DET.")
    parser.add_argument("--two_stage", action='store_true')
    # main() unpacks exactly two values: the last processed file name and its checkpoint count.
    parser.add_argument('--previous_preds_info', nargs=2, type=str, default=None, metavar=('LAST_FILE', 'COUNT'),
                        help="Resume after LAST_FILE using the checkpoint <COUNT>_hoi_eval.json in the output folder.")
    # Required: main() always builds the question-file path from it (even when only
    # evaluating), so fail at parse time instead of after the model has been loaded.
    parser.add_argument("--hoi_question_json_file", type=str, required=True, help="Root folder containing the question / ground-truth JSON sub-folders.")
    parser.add_argument("--hoi_pred_json_file", type=str, default=None, help="Path to a prediction JSON file; when given, generation is skipped and only evaluation runs.")
    # NOTE: --quantization is accepted but not read anywhere in this file.
    parser.add_argument("--quantization", action='store_true', help="Use quantization for the model.")
    parser.add_argument("--hf_pth", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=3)
    parser.add_argument('--person_settings', type=str, default='all', choices=['multiple', 'single', 'all'], help="evaluation settings for multiple-person, single-person or all")
    parser.add_argument('--prompt_box_type', type=str, default='person', choices=['person', 'all', 'none'], help="evaluation settings with provided boxes for human or human-object or None")
    parser.add_argument("--num_hoi_cls", type=int, default=600, help="HOI classes in the benchmark.")
    parser.add_argument("--lora_dir", type=str, default=None, help="Folder for pretrained lora qwen model.")
    parser.add_argument("--reasoning", default='none', choices=['none', 'v1', 'v2'], help='none means directly output answer without reasoning, v1 is reasoning + answer, v2 is answer + reasoning')
    args = parser.parse_args()

    main(args)