import argparse
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from PIL import Image
import torch
from qwen_vl_utils import process_vision_info
import os
import re 
import time
import json
from hico_text_label import MAP_AO_TO_HOI, OBJ_IDX_TO_OBJ_NAME, ACT_IDX_TO_ACT_NAME, RARE_HOI_IDX, HICO_INTERACTIONS, TIME_AMBIGUITY_HOI_INDEX
import pickle
from torchvision.ops.boxes import batched_nms, box_iou
import random 
from transformers import BitsAndBytesConfig
import gc
from torchvision.ops import batched_nms
from newbench_question_func import mllm_instancef1_eval, mllm_macrof1_eval, match_gtbox

def format_bbox(bbox):
    """
    Render a bounding box as a bracketed, comma-separated string with every
    coordinate shown to two decimal places.

    Example: [262.80762, 301.31366, 314.68677, 351.58646] -> "[262.81, 301.31, 314.69, 351.59]"

    If the box cannot be iterated or one of its entries cannot be formatted
    as a fixed-point number, ``str(bbox)`` is returned instead.
    """
    try:
        coords = [format(value, ".2f") for value in bbox]
    except Exception:
        # Best-effort fallback: hand back the plain string form of the input.
        return str(bbox)
    return f"[{', '.join(coords)}]"

def get_thres(output_dict, args):
    """
    Collect every predicted human-object score, grouped by HOI class index.

    Args:
        output_dict: mapping whose values each provide "ho_scores" (numeric
            scores) and 'ao_names' (labels of the form
            "<action> a/an <object>"), paired element by element.
        args: namespace providing ``num_hoi_cls``, the number of HOI classes.

    Returns:
        dict mapping each class index in ``range(args.num_hoi_cls)`` to the
        list of scores predicted for that class. Scores are the elements
        obtained by iterating a float32 tensor built from "ho_scores".
    """
    scores_by_hoi = {}
    for hoi_idx in range(args.num_hoi_cls):
        scores_by_hoi[hoi_idx] = []

    for record in output_dict.values():
        pair_scores = torch.tensor(record["ho_scores"], dtype=torch.float32)
        for score, label in zip(pair_scores, record['ao_names']):
            raw_action, raw_object = label.split(" a/an ")
            action_name = raw_action.replace(" ", "_")
            object_name = raw_object.replace(" ", "_")
            # Any action containing "no_interaction" (e.g. "no_interaction_with")
            # is collapsed to the bare "no_interaction" name before lookup.
            if "no_interaction" in action_name:
                action_name = "no_interaction"

            action_idx = ACT_IDX_TO_ACT_NAME.index(action_name)
            object_idx = OBJ_IDX_TO_OBJ_NAME.index(object_name)
            hoi_idx = int(MAP_AO_TO_HOI[action_idx, object_idx])
            scores_by_hoi[hoi_idx].append(score)

    return scores_by_hoi


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _safe_mean(values):
    """Return the arithmetic mean of ``values``, or 0.0 when it is empty."""
    values = list(values)
    return _safe_ratio(sum(values), len(values))


def _kth_largest_score(scores, k):
    """Return the k-th largest entry of the 1-D tensor ``scores`` as a float.

    ``k`` is clamped to the number of scores. When nothing can be ranked
    (no scores, or ``k`` below 1) ``float("inf")`` is returned so that no
    prediction passes the resulting threshold.
    """
    k = min(int(k), len(scores))
    if k < 1:
        return float("inf")
    return scores.topk(k).values[-1].item()


def main(args):
    """
    Evaluate saved HOI predictions against the HOI question annotations.

    Loads the question annotations and the prediction JSON, selects the
    predictions to keep for every question (by score threshold, by global
    rank, or by per-file rank), feeds them to ``mllm_instancef1_eval`` and
    ``mllm_macrof1_eval``, prints the metrics and writes them to
    ``<output>/<person_settings>_evaluation_results.txt``.

    Args:
        args: namespace providing ``output``, ``hoi_question_json_file``,
            ``hoi_pred_json_file``, ``prompt_box_type``, ``person_settings``,
            ``pred_select``, ``pred_thres``, ``save_pred`` and
            ``num_hoi_cls``.

    Raises:
        ValueError: if a rank-based ``pred_select`` is combined with a
            ``pred_thres`` whose integer part is below 1.
    """
    # In the rank modes pred_thres is a count of predictions to keep. A value
    # such as the 0.5 default truncates to 0 and used to fail with an opaque
    # IndexError on the first file; reject it up front with a clear message.
    if args.pred_select in ('rank', 'question_rank') and int(args.pred_thres) < 1:
        raise ValueError(
            f"pred_select='{args.pred_select}' needs pred_thres to be an integer "
            f"rank >= 1, got {args.pred_thres}"
        )

    output_folder = args.output
    os.makedirs(output_folder, exist_ok=True)

    # NOTE(review): there is no entry for prompt_box_type == 'all', so that
    # setting raises KeyError below — confirm which sub-folder it should use.
    hoi_question_sub_folder = {"none": "no_box", "person": "person_box"}
    anno_file_name = "replaced_merged_hoi_results_pure_negpreds_v5.json" if args.prompt_box_type == 'none' else "merged_hoi_results_pure_negpreds_v5.json"

    hoi_question_json_file = os.path.join(args.hoi_question_json_file, hoi_question_sub_folder[args.prompt_box_type], anno_file_name)
    print("Loading the annotation from", hoi_question_json_file)

    with open(hoi_question_json_file, 'r') as f:
        hoi_det_question = json.load(f)

    with open(args.hoi_pred_json_file, 'r') as f:
        output_dict = json.load(f)

    # Optional restriction of the evaluation to a single-/multiple-person split.
    setting_files = {
        "single": "outputs/hico_split_output/single_person.json",
        "multiple": "outputs/hico_split_output/multiple_person.json",
    }
    evaluation_files = None
    if args.person_settings in setting_files:
        with open(setting_files[args.person_settings], 'r') as f:
            evaluation_files = json.load(f)

    # Accumulators threaded through mllm_instancef1_eval for every file.
    all_ans_per_question = {'tp': [], 'fp': [], 'full_pred': 0, 'full_gt': 0, 'ood': []}
    f1_per_question = []
    macro_f1_dict = {}
    acc_top1 = 0
    acc_fullmatch = 0

    if args.pred_select == 'rank':
        # Global cut-off: the score of the pred_thres-th best prediction
        # across the whole prediction file.
        HOI_cls_predscore = get_thres(output_dict, args)
        all_preds = torch.tensor([item for sublist in HOI_cls_predscore.values() for item in sublist])
        thres_calc = _kth_largest_score(all_preds, args.pred_thres)
    else:
        thres_calc = args.pred_thres

    newbench_output_dict = {}
    for file in output_dict:
        if file not in hoi_det_question:
            continue
        if args.person_settings != "all" and file not in evaluation_files:
            continue
        newbench_output_dict[file] = {}

        response_process_list = output_dict[file]
        print("🔍 Processing file:", file)

        hboxes = torch.tensor(response_process_list["h_boxes"], dtype=torch.float32)
        oboxes = torch.tensor(response_process_list["o_boxes"], dtype=torch.float32)
        ho_scores = torch.tensor(response_process_list["ho_scores"], dtype=torch.float32)

        hoi_det_questioni = hoi_det_question[file]
        # An entry without a "QA_0" key is treated as a single question.
        if "QA_0" not in hoi_det_questioni:
            hoi_det_questioni_qa = {"QA_0": hoi_det_questioni}
        else:
            hoi_det_questioni_qa = hoi_det_questioni

        ### process the HOI prediction for this question
        for qli in hoi_det_questioni_qa:
            newbench_output_dict[file][qli] = []
            if len(hboxes) == 0 or len(oboxes) == 0:
                continue

            content_i = hoi_det_questioni_qa[qli]
            if args.prompt_box_type != 'none':
                # Keep only the predictions whose boxes match the boxes
                # supplied with the question.
                gt_hboxes = content_i['boxes']['human']
                gt_oboxes = content_i['boxes']['object'] if args.prompt_box_type == 'all' else None
                gt_pred_match, _, _ = match_gtbox(hboxes, oboxes, gt_hboxes, gt_oboxes, args.prompt_box_type, box_scale_factor=None)
                pred_ind = torch.nonzero(gt_pred_match, as_tuple=False)
                pred_ind = pred_ind[:, 1] if pred_ind.ndim == 2 else pred_ind[1]

                matched = pred_ind.tolist()
                predi = [response_process_list['ao_names'][i] for i in matched]
                scoresi = [ho_scores[i] for i in matched]
            else:
                predi = response_process_list['ao_names']
                scoresi = ho_scores

            if args.pred_select == 'question_rank':
                # NOTE(review): the cut-off is taken over every score of the
                # file (ho_scores), not only the scores matched to this
                # question (scoresi) — confirm this is intended.
                thres_calc = _kth_largest_score(ho_scores, args.pred_thres)
            for idx_rpi, rpii in enumerate(predi):
                if scoresi[idx_rpi] < thres_calc:
                    continue
                newbench_output_dict[file][qli].append(rpii)

        f1_per_question, macro_f1_dict, all_ans_per_question, acc_top1, acc_fullmatch = mllm_instancef1_eval(hoi_det_questioni, newbench_output_dict[file], f1_per_question, macro_f1_dict, all_ans_per_question, file, acc_top1, acc_fullmatch)

    if args.save_pred:
        # The "<count - 1>" prefix is kept as-is so existing output file
        # names do not change.
        out_file = os.path.join(output_folder, f"{len(output_dict) - 1}_hoi_eval_{args.pred_thres}.json")
        with open(out_file, "w") as f:
            json.dump(newbench_output_dict, f, indent=2)

    #### calculate prec and recall
    # Every ratio below falls back to 0.0 instead of raising
    # ZeroDivisionError when there is nothing to average.
    num_tp = len(all_ans_per_question['tp'])
    num_fp = len(all_ans_per_question['fp'])
    prec_all_ans_per_question = _safe_ratio(num_tp, num_tp + num_fp)
    recall_all_ans_per_question = _safe_ratio(num_tp, all_ans_per_question['full_gt'])
    print(f"All answers per question - Precision: {prec_all_ans_per_question:.4f}, Recall: {recall_all_ans_per_question:.4f}")

    instance_f1 = _safe_mean(f1_per_question)
    print(f"Instance F1: {instance_f1:.4f}")

    macro_f1_dict_hoicls = {i: {'tp': 0, 'fp': 0, 'gt': 0} for i in range(args.num_hoi_cls)}
    macro_f1_dict_hoicls, macro_f1_list, prec_list, rec_list = mllm_macrof1_eval(macro_f1_dict, args.num_hoi_cls, macro_f1_dict_hoicls)

    macro_f1 = _safe_mean(macro_f1_list.values())
    print(f"Macro F1: {macro_f1:.4f}")

    ### time ambiguous HOI class performance macro F1
    macro_f1_list_timeamb = {k: v for k, v in macro_f1_list.items() if k in TIME_AMBIGUITY_HOI_INDEX}
    timeamb_macro_f1 = _safe_mean(macro_f1_list_timeamb.values())

    micro_F1 = _safe_ratio(2 * (prec_all_ans_per_question * recall_all_ans_per_question), prec_all_ans_per_question + recall_all_ans_per_question)
    print(f"Micro F1 Score: {micro_F1:.4f}")

    all_question = len(hoi_det_question) if args.person_settings == "all" else len([i for i in evaluation_files if i in hoi_det_question])
    top1_acc = _safe_ratio(acc_top1, all_question)
    fullmatch_acc = _safe_ratio(acc_fullmatch, all_question)
    print("Full match prediction accuracy: ", fullmatch_acc)

    avg_cls_prec = _safe_mean(prec_list)
    avg_cls_rec = _safe_mean(rec_list)

    # All metrics are computed before the file is opened, so a failure can
    # no longer leave a half-written results file behind.
    results_txt = os.path.join(output_folder, args.person_settings + "_evaluation_results.txt")
    with open(results_txt, "w") as ftxt:
        ftxt.write(f"All answers per question - Precision: {prec_all_ans_per_question:.4f}, Recall: {recall_all_ans_per_question:.4f}\n")
        ftxt.write(f"Instance F1: {instance_f1:.4f}\n")
        ftxt.write(f"Micro F1 Score: {micro_F1:.4f}\n")
        ftxt.write(f"Macro F1: {macro_f1:.4f}\n")
        ftxt.write(f"Time-ambiguous Macro F1: {timeamb_macro_f1:.4f}\n")
        ftxt.write(f"Averaged HOI class precision: {avg_cls_prec:.4f}\n")
        ftxt.write(f"Averaged HOI class recall: {avg_cls_rec:.4f}\n")
        ftxt.write(f"Top 1 prediction accuracy: {top1_acc:.4f}\n")
        ftxt.write(f"Full match prediction accuracy: {fullmatch_acc:.4f}\n")



if __name__ == "__main__":
    # Command-line entry point: parse the evaluation options and run main().
    parser = argparse.ArgumentParser(description="Evaluate saved HOI detection predictions against the HOI question annotations.")
    parser.add_argument("--image_folder", type=str, required=True, help="Path to the input image.")
    parser.add_argument("--prompt", type=str, default=None, help="Text prompt to use.")
    parser.add_argument("--max_tokens", type=int, default=2048, help="Max new tokens to generate.")
    parser.add_argument("--output", type=str, default="./output", help="Folder to save predictions.")
    parser.add_argument("--detection_pth", type=str, default=None, help="pth saved detected objects for HICO-DET.")
    parser.add_argument('--previous_preds_info', nargs='+', type=str, default=None)
    # Both paths are used unconditionally by main(), so a missing value is
    # reported by argparse instead of failing later with a TypeError on None.
    parser.add_argument("--hoi_question_json_file", type=str, required=True, help="Folder to get gt.")
    parser.add_argument("--hoi_pred_json_file", type=str, required=True, help="JSON file with the saved HOI predictions to evaluate.")
    parser.add_argument("--pred_thres", type=float, default=0.5, help="score threshold for --pred_select thres, or number of top-ranked predictions to keep (integer >= 1) for rank / question_rank")
    parser.add_argument("--save_pred", action='store_true')
    parser.add_argument('--person_settings', type=str, default='all', choices=['multiple', 'single', 'all'], help="evaluation settings for multiple-person, single-person or all")
    parser.add_argument('--prompt_box_type', type=str, default='person', choices=['person', 'all', 'none'], help="evaluation settings with provided boxes for human or human-object or None")
    parser.add_argument('--pred_select', type=str, default='question_rank', choices=['rank', 'thres', 'question_rank'], help="selecting HOI prediction")
    parser.add_argument("--num_hoi_cls", type=int, default=600, help="HOI classes in the benchmark.")

    args = parser.parse_args()

    main(args)