
import os
import torch
import tqdm
import re
import json
from collections import defaultdict

# --- Inputs -----------------------------------------------------------------
# JSON dataset read by process_cross_instance_dict().
dataset_path = "./dataset/sftdatabbox.json"
# Directory listed by merge_neg_data(); one saved tensor file per sample.
teacher_path = "./dataset/teacher_attn_map"
# Directory holding the adversarial attention map of each sample
# (same file names as in teacher_path).
semantic_path = "./dataset/adversarial_attn_map/"

# --- Intermediate / outputs -------------------------------------------------
# JSON mapping written by process_cross_instance_dict() and read back by
# merge_neg_data().
cross_instance_dict_path = "./dataset/cross_instance_dict.json"
# Directory that receives the merged negative attention maps.
save_dir = "./dataset/neg_attn_map/"



# Compiled once: both patterns are applied to every sample's image path.
_BLOCK_ID_RE = re.compile(r'/([^/_]+_block_\d+)_\d+/')
_IMAGE_NAME_RE = re.compile(r'/([^/]+)\.jpg$')


def _parse_image_path(image_path):
    """Return ``(block_id, image_name)`` extracted from *image_path*.

    Raises:
        ValueError: if the path does not contain a ``<x>_block_<n>_<m>/``
            directory component or does not end in ``.jpg``.
    """
    block_match = _BLOCK_ID_RE.search(image_path)
    name_match = _IMAGE_NAME_RE.search(image_path)
    if block_match is None or name_match is None:
        raise ValueError(
            f"Cannot extract block id / image name from image path: {image_path!r}"
        )
    return block_match.group(1), name_match.group(1)


def process_cross_instance_dict(input_file=dataset_path):
    """Build the cross-instance lookup table and write it to disk.

    The dataset at *input_file* is a JSON list whose items carry an
    ``"image"`` list (only the first path is used) and a ``"conversations"``
    list of ``{"from": ..., "value": ...}`` turns.

    Two files are written:

    * ``block_details_texts.json`` (in the current working directory):
      block id -> image names of the samples having a user turn that contains
      ``"[Details of the Target]"``. An image name is appended once per
      matching turn.
    * ``cross_instance_dict_path``: image name -> up to 64 image names from
      the same block, excluding the image itself.

    Args:
        input_file: Path of the dataset JSON file.

    Raises:
        ValueError: if an image path does not match the expected layout.
    """
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    block_dict = defaultdict(list)
    # (block_id, image_name) per sample, kept so each path is parsed once.
    parsed = []

    # The file imports the tqdm *module*, so the progress bar is tqdm.tqdm.
    for item in tqdm.tqdm(data):
        block_id, image_name = _parse_image_path(item["image"][0])
        parsed.append((block_id, image_name))

        for conv in item["conversations"]:
            if conv["from"] == "user" and "[Details of the Target]" in conv["value"]:
                block_dict[block_id].append(image_name)

    # Save all samples corresponding to each block_id.
    with open("block_details_texts.json", "w", encoding="utf-8") as f:
        json.dump(block_dict, f, ensure_ascii=False, indent=2)

    print("Dictionary construction completed")

    # Build image_name -> other image_names under the same block.
    cross_instance_dict = {}
    for block_id, image_name in tqdm.tqdm(parsed):
        # .get() avoids inserting empty lists into the defaultdict for blocks
        # that collected no samples.
        siblings = [
            img for img in block_dict.get(block_id, []) if img != image_name
        ]
        # Keep at most 64 cross-instance candidates per image.
        cross_instance_dict[image_name] = siblings[:64]

    with open(cross_instance_dict_path, "w", encoding="utf-8") as f:
        json.dump(cross_instance_dict, f, ensure_ascii=False, indent=2)



def merge_neg_data():
    """Assemble and save the list of negative attention maps for every sample.

    For each file in ``teacher_path`` the following tensors are collected, in
    this order, and saved as one list under ``save_dir`` with the same file
    name:

    1. Cross-instance: the teacher map of every image listed for this sample
       in the JSON at ``cross_instance_dict_path``, averaged over dim 0.
    2. Adversarial: the map with the same file name in ``semantic_path``,
       averaged over dim 0.
    3. Perturbed: 8 copies of the row-normalised teacher map plus scaled,
       row-normalised absolute Gaussian noise.
    4. Random: row-normalised absolute Gaussian noise, added only while the
       list holds fewer than 64 entries.

    NOTE(review): items 1-2 are reduced over dim 0 while items 3-4 keep the
    teacher map's full shape -- confirm the consumer expects mixed shapes.
    NOTE(review): with more than 55 cross-instance entries the list exceeds
    64 items and is saved as is -- confirm whether it should be capped.

    Raises:
        KeyError: if a file in ``teacher_path`` has no entry in the
            cross-instance dictionary.
    """
    os.makedirs(save_dir, exist_ok=True)
    file_names = os.listdir(teacher_path)

    with open(cross_instance_dict_path, "r", encoding="utf-8") as f:
        cross_instance_dict = json.load(f)

    num_perturbed = 8
    noise_sigma = 0.5
    target_total = 64

    for image_path in tqdm.tqdm(file_names):
        # File names are "<image_name>.pt"; the stem is the dictionary key.
        image_name = os.path.splitext(image_path)[0]
        teacher_attn_scores = torch.load(
            os.path.join(teacher_path, image_path), weights_only=True
        )

        neg_attn = []

        # Cross-instance attention
        for neighbour in cross_instance_dict[image_name]:
            neighbour_scores = torch.load(
                os.path.join(teacher_path, neighbour + ".pt"), weights_only=True
            )
            neg_attn.append(neighbour_scores.mean(dim=0))

        # Adversarial attention
        adversarial_attn_scores = torch.load(
            os.path.join(semantic_path, image_path), weights_only=True
        ).mean(dim=0)
        neg_attn.append(adversarial_attn_scores)

        # Perturbed attention: the teacher normalisation does not depend on
        # the loop variable, so it is computed once up front.
        teacher_attn_scores = teacher_attn_scores / (
            teacher_attn_scores.sum(dim=-1, keepdim=True) + 1e-8
        )
        for _ in range(num_perturbed):
            noise = torch.abs(torch.randn_like(teacher_attn_scores))
            noise = noise / (noise.sum(dim=-1, keepdim=True) + 1e-8)
            neg_attn.append(noise_sigma * noise + teacher_attn_scores)

        # Random attention: pad up to target_total entries. The count is
        # clamped at zero because the entries collected above can already
        # exceed target_total, and a negative size makes torch.randn raise.
        num_random = max(0, target_total - len(neg_attn))
        rand = torch.randn(
            (num_random, *teacher_attn_scores.shape),
            device=teacher_attn_scores.device,
        )
        rand = torch.abs(rand)
        rand = rand / (rand.sum(dim=-1, keepdim=True) + 1e-8)
        # Iterating a tensor yields its slices along dim 0.
        neg_attn.extend(rand)

        torch.save(neg_attn, os.path.join(save_dir, image_path))

    

def main():
    """Run the two pipeline stages in order.

    The cross-instance dictionary is built first because merge_neg_data()
    reads the file that process_cross_instance_dict() writes.
    """
    process_cross_instance_dict()
    merge_neg_data()


if __name__ == "__main__":
    main()


