import os
# Cap the thread pools of the numerical backends (OpenMP, MKL, OpenBLAS,
# numexpr, vecLib) at 4 threads each.
# NOTE(review): these are set ahead of the torch / numpy imports below,
# presumably because those libraries read the variables when they are first
# loaded — confirm before moving these lines.
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "4"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
import argparse
import ast
from transformers import AutoProcessor, BitsAndBytesConfig
from PIL import Image
import torch

import re 
import time
import json
from hico_text_label import MAP_AO_TO_HOI, OBJ_IDX_TO_OBJ_NAME, ACT_IDX_TO_ACT_NAME, RARE_HOI_IDX, HICO_INTERACTIONS, TIME_AMBIGUITY_HOI_INDEX
import pickle
from torchvision.ops.boxes import batched_nms, box_iou
# from script_hico_evaluation_descrip import merge_json_files
import random 
import gc
import math
import numpy as np
import torchvision.transforms as T
# from decord import VideoReader, cpu
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer, AutoConfig
from huggingface_hub import snapshot_download
from newbench_question_func import newbench_interaction_question, mllm_instancef1_eval, mllm_macrof1_eval, generate_candidate_pairs, match_gtbox, parse_detection_answer, newbench_detection_question, newbench_pre_question_imgsize, parse_imgsize_answer



# Per-channel (R, G, B) mean and standard deviation handed to T.Normalize in
# build_transform().
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Build the preprocessing pipeline applied to every image tile.

    The returned torchvision transform, in order:
      1. converts the PIL image to RGB when it is in another mode,
      2. resizes it to ``input_size`` x ``input_size`` (bicubic),
      3. converts it to a tensor,
      4. normalises it with IMAGENET_MEAN / IMAGENET_STD.

    Args:
        input_size: Edge length, in pixels, of the square output.

    Returns:
        A ``T.Compose`` instance.
    """
    def _ensure_rgb(img):
        # Leave RGB images untouched; convert everything else.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tile grid whose width/height ratio is closest to the image's.

    Args:
        aspect_ratio: Width divided by height of the source image.
        target_ratios: Iterable of candidate grids; element 0 is the number of
            columns and element 1 the number of rows.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Edge length of one square tile.

    Returns:
        The candidate with the smallest absolute ratio difference, or
        ``(1, 1)`` when ``target_ratios`` is empty. When a later candidate
        ties exactly with the current best, it replaces the best only if the
        source image area exceeds half of that candidate grid's pixel area
        (``image_size**2 * cols * rows``).
    """
    pixel_area = width * height
    half_tile_area = 0.5 * image_size * image_size
    chosen = (1, 1)
    smallest_gap = float('inf')

    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        # Exact tie: prefer the later candidate only for large enough images.
        if gap == smallest_gap and pixel_area > half_tile_area * cols * rows:
            chosen = candidate
    return chosen

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Resize an image onto a tile grid and cut it into square tiles.

    Args:
        image: PIL image to split.
        min_num: Minimum number of tiles in the grid.
        max_num: Maximum number of tiles in the grid.
        image_size: Edge length of each square tile.
        use_thumbnail: When true and the grid has more than one tile, append
            a whole-image thumbnail of ``image_size`` x ``image_size``.

    Returns:
        ``(tiles, target_width, target_height)`` where ``tiles`` is the list
        of cropped PIL images in row-major order (plus the optional
        thumbnail) and the two sizes are those of the resized image.
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    # Every (cols, rows) grid whose tile count lies in [min_num, max_num],
    # ordered by tile count.
    grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    )
    grids = sorted(grids, key=lambda g: g[0] * g[1])

    # Grid that best matches the source aspect ratio.
    grid = find_closest_aspect_ratio(src_ratio, grids, src_w, src_h, image_size)

    target_width = image_size * grid[0]
    target_height = image_size * grid[1]
    blocks = grid[0] * grid[1]

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size

    processed_images = []
    for idx in range(blocks):
        row, col = divmod(idx, tiles_per_row)
        crop_box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        processed_images.append(resized_img.crop(crop_box))

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images, target_width, target_height

def load_image(image_file, input_size=448, max_num=12):
    """Load an image file and turn it into a stacked tensor of tiles.

    The image is opened as RGB, split by ``dynamic_preprocess`` (thumbnail
    enabled) and every tile is run through ``build_transform``.

    Args:
        image_file: Path (or file object) accepted by ``Image.open``.
        input_size: Edge length of each square tile.
        max_num: Maximum number of tiles in the grid.

    Returns:
        ``(pixel_values, target_width, target_height)``: the stacked tile
        tensor and the size of the resized image reported by
        ``dynamic_preprocess``.
    """
    rgb_image = Image.open(image_file).convert('RGB')
    to_tensor = build_transform(input_size=input_size)
    tiles, target_width, target_height = dynamic_preprocess(
        rgb_image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = torch.stack(list(map(to_tensor, tiles)))
    return pixel_values, target_width, target_height

def split_model(model_name):
    """Build a device map that spreads the language-model layers over the GPUs.

    The number of decoder layers is read from
    ``config.llm_config.num_hidden_layers``. GPU 0 also hosts the vision
    model, so it is counted as half a GPU and receives half a share of the
    layers. The vision model, projector, embeddings, final norm, rotary
    embedding, output heads and the last decoder layer are all pinned to
    GPU 0.

    Args:
        model_name: Model id or path passed to ``AutoConfig.from_pretrained``.

    Returns:
        Dict mapping module names to GPU indices.
    """
    gpu_count = torch.cuda.device_count()
    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    total_layers = model_config.llm_config.num_hidden_layers

    # GPU 0 counts as half a GPU because it also carries the ViT.
    full_share = math.ceil(total_layers / (gpu_count - 0.5))
    layers_per_gpu = [full_share] * gpu_count
    layers_per_gpu[0] = math.ceil(layers_per_gpu[0] * 0.5)

    device_map = {}
    next_layer = 0
    for gpu_idx, quota in enumerate(layers_per_gpu):
        for _ in range(quota):
            device_map[f'language_model.model.layers.{next_layer}'] = gpu_idx
            next_layer += 1

    # Modules that must live on GPU 0 (the last decoder layer is moved there too).
    pinned_to_first_gpu = (
        'vision_model',
        'mlp1',
        'language_model.model.tok_embeddings',
        'language_model.model.embed_tokens',
        'language_model.output',
        'language_model.model.norm',
        'language_model.model.rotary_emb',
        'language_model.lm_head',
        f'language_model.model.layers.{total_layers - 1}',
    )
    for module_name in pinned_to_first_gpu:
        device_map[module_name] = 0

    return device_map


def extract_assistant_response(text):
    """Return the assistant's reply from a decoded chat transcript.

    The transcript is split into lines and every line is stripped before any
    comparison, so a marker line such as ``" assistant "`` or
    ``"assistant\\r"`` is recognised on every path.

    Args:
        text: Decoded model output, possibly containing a line that is
            exactly ``assistant`` (ignoring surrounding whitespace).

    Returns:
        * the last line, when the second-to-last line is the marker;
        * otherwise all lines after the first marker line, joined by spaces;
        * otherwise the whole text.
        The result is always stripped of surrounding whitespace, including
        for single-line input.
    """
    cleaned = text.strip()
    lines = [line.strip() for line in cleaned.split("\n")]
    if len(lines) < 2:
        return cleaned
    if lines[-2] == 'assistant':
        return lines[-1]
    if 'assistant' in lines:
        idx = lines.index('assistant')
        return " ".join(lines[idx + 1:]).strip()
    return cleaned




def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def main(args):
    """Answer multiple-choice HOI questions with the model, then score the answers.

    Inference phase (skipped when ``args.hoi_pred_json_file`` is given):
    every image under ``args.image_folder`` that has an entry in the question
    JSON is tiled, optionally asked for its processed size and for object
    detections, and then asked one four-choice question per QA entry. The
    chosen options are stored per file and question, checkpointed to
    ``<cnt>_hoi_eval.json`` and finally written to ``merged_hoi_eval.json``
    in ``args.output``.

    Evaluation phase: predictions are scored with ``mllm_instancef1_eval`` and
    ``mllm_macrof1_eval``; precision/recall, instance/micro/macro F1 and
    accuracies are printed and written to
    ``<person_settings>_evaluation_results.txt``. Ratios with an empty
    denominator are reported as 0 instead of raising ``ZeroDivisionError``.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file).
    """
    image_folder = args.image_folder
    output_folder = args.output

    model_id = args.model

    os.makedirs(output_folder, exist_ok=True)

    if args.hoi_pred_json_file is None:
        print(f"🚀 Loading model: {model_id}")

        # Without --hf_home the model id is used as given instead of crashing
        # on os.path.join(None, ...).
        if args.hf_home is not None:
            model_id = os.path.join(args.hf_home, model_id)

        if args.quantization:
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",           # nf4 is the more stable choice
                bnb_4bit_compute_dtype=torch.bfloat16
            )

            model = AutoModel.from_pretrained(
                model_id,
                quantization_config=quant_config,
                device_map="auto",                  # let the loader place modules on GPUs
                trust_remote_code=True
            ).eval()

            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )

        else:
            # The manual device map is only needed on the non-quantized path.
            device_map = split_model(model_id)

            model = AutoModel.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                load_in_8bit=False,
                low_cpu_mem_usage=True,
                use_flash_attn=True,
                trust_remote_code=True,
                device_map=device_map).eval()
            tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False)

    # NOTE(review): 'all' is an accepted --prompt_box_type but has no entry
    # here, so the lookup below raises KeyError for it.
    hoi_question_sub_folder = {"none": "no_box", "person": "person_box" }
    anno_file_name = "replaced_merged_hoi_results_pure_negpreds_v5.json" if args.prompt_box_type == 'none'  else "merged_hoi_results_pure_negpreds_v5.json"
    hoi_question_json_file = os.path.join(args.hoi_question_json_file, hoi_question_sub_folder[args.prompt_box_type], anno_file_name)
    print("Loading the annotation from", hoi_question_json_file)

    with open(hoi_question_json_file, 'r') as f:
        hoi_det_question = json.load(f)

    detection_results = None
    if args.detection_pth is not None:
        print(f"🔎 Loading detection results from: {args.detection_pth}")
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        with open(args.detection_pth, 'rb') as f:
            detection_results = pickle.load(f)

    # Resume support: previous_preds_info is (last processed file, its count).
    if args.previous_preds_info is not None:
        previous_preds, cnt_num = args.previous_preds_info
        generate_flag = 0
        with open(os.path.join(output_folder, str(cnt_num)+"_hoi_eval.json"), 'r') as f:
            output_dict = json.load(f)
    else:
        generate_flag = 1
        previous_preds = None
        output_dict = {}

    cnt = 0
    all_ans_per_question = {'tp': [], 'fp': [], 'full_pred': 0, 'full_gt': 0, 'ood': []} ### tp, fp, full_pred, full_gt, ood
    f1_per_question = []
    macro_f1_dict = {}
    acc_top1 = 0
    acc_fullmatch = 0

    ### load the settings
    setting_files = {
        "single": "outputs/hico_split_output/single_person.json",
        "multiple": "outputs/hico_split_output/multiple_person.json",
    }
    evaluation_files = None
    if args.person_settings in setting_files:
        with open(setting_files[args.person_settings], 'r') as f:
            evaluation_files = json.load(f)

    if args.hoi_pred_json_file is None:
        for root, _, files in os.walk(image_folder):
            for file in sorted(files):
                if file not in hoi_det_question:
                    continue
                if args.person_settings != "all" and file not in evaluation_files:
                    continue

                # Skip everything up to and including the last file of the previous run.
                if previous_preds is not None and file == previous_preds:
                    generate_flag = 1
                    cnt = int(cnt_num)
                    print(f"⚠️ Skipping {file} as it already exists in previous predictions.")
                    continue
                if previous_preds is not None and generate_flag == 0:
                    print(f"⚠️ Skipping {file} as it already exists in previous predictions.")
                    continue

                start_time = time.time()
                image_path = os.path.join(root, file)
                # Only the original size is needed here; load_image() decodes
                # the pixels itself, so avoid a second full decode.
                with Image.open(image_path) as raw_image:
                    width, height = raw_image.size

                pixel_values, resized_width, resized_height = load_image(image_path, max_num=12)
                pixel_values = pixel_values.to(torch.bfloat16).cuda()

                # NOTE(review): --max_tokens is parsed but not used; generation
                # is capped at 512 new tokens here.
                generation_config = dict(max_new_tokens=512, do_sample=True)

                hoi_det_questioni = hoi_det_question[file]
                # Entries without QA_* keys hold a single question; wrap them.
                if "QA_0" not in hoi_det_questioni:
                    hoi_det_questioni_qa = {"QA_0": hoi_det_questioni}
                else:
                    hoi_det_questioni_qa = hoi_det_questioni

                print("🔍 Processing file:", file)
                conversation_history_list = []
                pixel_values_list = []
                num_patches_list = []
                question_list = []
                choices_list = []

                if args.prompt_box_type != "none":
                    # Ask the model which image size it sees so GT boxes can be rescaled.
                    prompt = newbench_pre_question_imgsize()
                    responses = model.chat(tokenizer, pixel_values, prompt, generation_config)
                    print("MLLM processed image size: ")
                    print(responses)
                    img_size = parse_imgsize_answer(responses)

                    if len(img_size) == 0:
                        print("Fail to get MLLM processed image size.")
                        continue
                    try:
                        resized_width, resized_height = img_size
                        box_scale_factor = torch.tensor([resized_width/width, resized_height/height, resized_width/width, resized_height/height])
                    except Exception as exc:
                        # Best-effort skip, without swallowing KeyboardInterrupt/SystemExit.
                        print("Fail to get MLLM processed image size.", repr(exc))
                        continue
                else:
                    box_scale_factor = None

                if args.two_stage and args.detection_pth is None:
                    prompt = newbench_detection_question(args.prompt_box_type)

                    responses = model.chat(tokenizer, pixel_values, prompt, generation_config)
                    det_result = parse_detection_answer(args.prompt_box_type, responses)
                    print("MLLM object detection results: ")
                    print(responses)

                    if len(det_result) == 0:
                        print("🙅 fail get detection results, skip")
                        continue

                for qi in hoi_det_questioni_qa:
                    content_i = hoi_det_questioni_qa[qi]
                    hboxes_gt = None
                    oboxes_gt = None
                    if args.prompt_box_type != 'none':
                        hboxes_gt = content_i['boxes']["human"]
                    if args.prompt_box_type == 'all':
                        oboxes_gt = content_i['boxes']["object"]

                    gt_choices = content_i['gt_choices']
                    incorrect_choices = content_i['wrong_choices']

                    choices = gt_choices + incorrect_choices
                    random.shuffle(choices)

                    # Pad to four options.
                    while len(choices) < 4:
                        print("☹️ missing choices and add None as one")
                        choices.append('None')
                        print(f"choices: {choices}")

                    # Answers are mapped to the letters A-D only, so a question
                    # with more than four options cannot be asked. Skip it with
                    # a warning instead of dropping into a debugger.
                    if len(choices) != 4:
                        print(f"⚠️ {file} / {qi}: expected 4 choices, got {len(choices)}; skip question")
                        continue

                    prompt1 = None
                    if not args.two_stage:
                        ### using GT boxes
                        prompt1 = newbench_interaction_question(choices, prompt_box_type = args.prompt_box_type, box_scale_factor = box_scale_factor, hbox = hboxes_gt, obox = oboxes_gt, box_resize=True)

                    else:
                        if args.detection_pth is not None:
                            det_result = detection_results[file][0]

                        try:
                            candidate_pairs, _ = generate_candidate_pairs(det_result, human_only = (args.prompt_box_type == 'person'), prompt_box_type = args.prompt_box_type)
                        except Exception as exc:
                            print(" detection has format issue, skip", repr(exc))
                            continue
                        # Use the first detected pair that matches a GT box.
                        for ho_det_box_i in candidate_pairs:
                            sub_box = ho_det_box_i[0][1] if args.prompt_box_type != 'none' else None
                            obj_box = ho_det_box_i[1][1] if args.prompt_box_type == 'all' else None

                            if args.detection_pth is not None:
                                gt_box_match, h_det_box_i, o_det_box_i = match_gtbox(sub_box, obj_box, hboxes_gt, oboxes_gt, args.prompt_box_type, box_scale_factor = None)
                            else:
                                gt_box_match, h_det_box_i, o_det_box_i = match_gtbox(sub_box, obj_box, hboxes_gt, oboxes_gt, args.prompt_box_type, box_scale_factor)

                            # Kept as an equality test: match_gtbox's return
                            # type is not visible from this file.
                            if gt_box_match == False:
                                continue

                            if args.detection_pth is not None:
                                prompt1 = newbench_interaction_question(choices, prompt_box_type = args.prompt_box_type, box_scale_factor = box_scale_factor, hbox = h_det_box_i, obox = o_det_box_i, box_resize=True)
                            else:
                                prompt1 = newbench_interaction_question(choices, prompt_box_type = args.prompt_box_type, hbox = h_det_box_i, obox = o_det_box_i)
                            break

                    if prompt1 is None:
                        continue

                    pixel_values_list.append(pixel_values)
                    num_patches_list.append(pixel_values.size(0))
                    conversation_history_list.append(prompt1)
                    question_list.append(qi)
                    choices_list.append({'A': choices[0], 'B': choices[1], 'C': choices[2], 'D': choices[3]})

                # Run the questions of this image through the model in batches.
                response_list = []
                chunk_size = args.batch_size
                for i in range(0, len(conversation_history_list), chunk_size):
                    chunk = conversation_history_list[i:i + chunk_size]
                    img_chunk = pixel_values_list[i:i + chunk_size]
                    img_chunk_tensor = torch.cat(img_chunk, dim=0)
                    num_patches_list_chunk = num_patches_list[i:i + chunk_size]

                    response_chunk = model.batch_chat(
                        tokenizer,
                        img_chunk_tensor,
                        num_patches_list=num_patches_list_chunk,
                        questions=chunk,
                        generation_config=generation_config
                    )
                    print("\n📢 Prediction for file : ")
                    print(response_chunk)
                    response_list.extend(response_chunk)

                output_dict[file] = {}

                # Map the answered letters back to the option texts.
                for response_i, qli, chsi in zip(response_list, question_list, choices_list):
                    ans_i = [x.strip().upper().rstrip('.') for x in response_i.split(',') if x.strip()]
                    picked = [chsi[letter] for letter in ans_i if letter in chsi]
                    output_dict[file].setdefault(qli, []).extend(picked)

                del response_list
                gc.collect()
                torch.cuda.empty_cache()

                cnt += 1
                # NOTE(review): len(files) is the file count of the current
                # directory only, not of the whole walk.
                if cnt % 500 == 0 or cnt >= len(files):
                    out_file = os.path.join(output_folder, f"{cnt}_hoi_eval.json")
                    with open(out_file, "w") as f:
                        json.dump(output_dict, f, indent=2)
                    print(f"✅ Saved predictions to {out_file}")

                end_time = time.time()  # end time counting
                elapsed_time = end_time - start_time
                print(f"🕒 Time for {file}: {elapsed_time:.2f} seconds\n")
        out_file = os.path.join(output_folder, "merged_hoi_eval.json")
        with open(out_file, "w") as f:
            json.dump(output_dict, f, indent=2)

    else:
        with open(args.hoi_pred_json_file, 'r') as f:
            output_dict = json.load(f)

    ######### evaluation
    macro_f1_dict_hoicls = {i: {'tp': 0, 'fp': 0, 'gt': 0} for i in range(args.num_hoi_cls)}
    for file in hoi_det_question:
        if args.person_settings != "all" and file not in evaluation_files:
            continue

        hoi_det_questioni = hoi_det_question[file]
        # Files without predictions are scored as an empty answer.
        response_process_list = output_dict.get(file, {'QA_0': []})

        f1_per_question, macro_f1_dict, all_ans_per_question, acc_top1, acc_fullmatch = mllm_instancef1_eval(hoi_det_questioni, response_process_list, f1_per_question, macro_f1_dict, all_ans_per_question, file, acc_top1, acc_fullmatch)

    #### calculate prec and recall over all answers
    num_tp = len(all_ans_per_question['tp'])
    num_fp = len(all_ans_per_question['fp'])
    prec_all_ans_per_question = _safe_ratio(num_tp, num_tp + num_fp)
    recall_all_ans_per_question = _safe_ratio(num_tp, all_ans_per_question['full_gt'])
    print(f"All answers per question - Precision: {prec_all_ans_per_question:.4f}, Recall: {recall_all_ans_per_question:.4f}")

    instance_f1 = _safe_ratio(sum(f1_per_question), len(f1_per_question))
    print(f"Instance F1: {instance_f1:.4f}")

    macro_f1_dict_hoicls, macro_f1_list, prec_list, rec_list = mllm_macrof1_eval(macro_f1_dict, args.num_hoi_cls, macro_f1_dict_hoicls)

    macro_f1 = _safe_ratio(sum(macro_f1_list.values()), len(macro_f1_list))
    print(f"Macro F1: {macro_f1:.4f}")

    ### time ambiguous HOI class performance macro F1
    macro_f1_list_timeamb = {k: v for k, v in macro_f1_list.items() if k in TIME_AMBIGUITY_HOI_INDEX}
    macro_f1_timeamb = _safe_ratio(sum(macro_f1_list_timeamb.values()), len(macro_f1_list_timeamb))
    print(f"Time-ambiguous Macro F1: {macro_f1_timeamb:.4f}")

    micro_F1 = _safe_ratio(
        2 * (prec_all_ans_per_question * recall_all_ans_per_question),
        prec_all_ans_per_question + recall_all_ans_per_question,
    )
    print(f"Micro F1: {micro_F1:.4f}")

    all_question = len(hoi_det_question) if args.person_settings == "all" else len([i for i in evaluation_files if i in hoi_det_question])
    top1_accuracy = _safe_ratio(acc_top1, all_question)
    fullmatch_accuracy = _safe_ratio(acc_fullmatch, all_question)
    print("Top 1 prediction accuracy: ", top1_accuracy)
    print("Full match prediction accuracy: ", fullmatch_accuracy)

    results_txt = os.path.join(output_folder, args.person_settings + "_evaluation_results.txt")
    with open(results_txt, "w") as ftxt:
        ftxt.write(f"All answers per question - Precision: {prec_all_ans_per_question:.4f}, Recall: {recall_all_ans_per_question:.4f}\n")
        ftxt.write(f"Instance F1: {instance_f1:.4f}\n")
        ftxt.write(f"Micro F1 Score: {micro_F1:.4f}\n")
        ftxt.write(f"Macro F1: {macro_f1:.4f}\n")
        ftxt.write(f"Time-ambiguous Macro F1: {macro_f1_timeamb:.4f}\n")
        ftxt.write(f"Averaged HOI class precision: {_safe_ratio(sum(prec_list), len(prec_list)):.4f}\n")
        ftxt.write(f"Averaged HOI class recall: {_safe_ratio(sum(rec_list), len(rec_list)):.4f}\n")
        ftxt.write(f"Top 1 prediction accuracy: {top1_accuracy:.4f}\n")
        ftxt.write(f"Full match prediction accuracy: {fullmatch_accuracy:.4f}\n")

    

     

if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(description="Run OpenGVLab/InternVL3-38B for HOI detection with image input.")
    parser.add_argument("--image_folder", type=str, required=True, help="Folder (walked recursively) containing the input images.")
    parser.add_argument("--model", type=str, default="OpenGVLab/InternVL3-38B", help="Model name or path.")
    parser.add_argument("--prompt", type=str, default=None, help="Text prompt to use (currently not read by main).")
    parser.add_argument("--max_tokens", type=int, default=2048, help="Max new tokens to generate (currently not read by main).")
    parser.add_argument("--output", type=str, default="./output", help="Folder to save predictions and evaluation results.")
    parser.add_argument("--detection_pth", type=str, default=None, help="Pickle file with pre-computed object detections for HICO-DET.")
    parser.add_argument("--two_stage", action='store_true', help="Detect boxes first, then ask the interaction question on matched boxes.")
    # main() unpacks exactly two values, so require exactly two here.
    parser.add_argument('--previous_preds_info', nargs=2, type=str, default=None, metavar=('LAST_FILE', 'COUNT'),
                        help="Resume a run: name of the last processed image and the count of the <COUNT>_hoi_eval.json checkpoint in --output.")
    # main() always joins this folder into the annotation path, so it is mandatory.
    parser.add_argument("--hoi_question_json_file", type=str, required=True, help="Root folder containing the HOI question/annotation JSON files.")
    parser.add_argument("--hoi_pred_json_file", type=str, default=None, help="JSON file with saved predictions; when given, inference is skipped and only evaluation runs.")
    parser.add_argument("--quantization", action='store_true', help="Use quantization for the model.")
    parser.add_argument("--hf_pth", type=str, default=None, help="Currently not read by main.")
    parser.add_argument("--batch_size", type=int, default=3, help="Number of questions sent to the model per batch.")
    parser.add_argument('--person_settings', type=str, default='all', choices=['multiple', 'single', 'all'], help="evaluation settings for multiple-person, single-person or all")
    parser.add_argument('--prompt_box_type', type=str, default='person', choices=['person', 'all', 'none'], help="evaluation settings with provided boxes for human or human-object or None")
    parser.add_argument("--num_hoi_cls", type=int, default=600, help="HOI classes in the benchmark.")
    parser.add_argument("--hf_home", type=str, default=None, help="Folder containing the Hugging Face checkpoints; joined with --model when given.")

    args = parser.parse_args()

    main(args)