# Log marker emitted before the processor/model are loaded further down.
print("start")
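# Captured output (excerpt) of printing f"{name}: {module}" for the entries of
# model.named_modules(); the root model is listed first under the empty name "".
# : LlavaNextForConditionalGeneration(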
#   (vision_tower): CLIPVisionModel(
#     (vision_model): CLIPVisionTransformer(
#       (embeddings): CLIPVisionEmbeddings(
#         (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
#         (position_embedding): Embedding(577, 1024)
#       )
#       (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#       (encoder): CLIPEncoder(
#         (layers): ModuleList(
#           (0-23): 24 x CLIPEncoderLayer(
#             (self_attn): CLIPSdpaAttention(
#               (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
#               (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
#               (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
#               (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
#             )
#             (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#             (mlp): CLIPMLP(
#               (activation_fn): QuickGELUActivation()
#               (fc1): Linear(in_features=1024, out_features=4096, bias=True)
#               (fc2): Linear(in_features=4096, out_features=1024, bias=True)
#             )
#             (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#           )
#         )
#       )
#       (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#     )
#   )
#   (multi_modal_projector): LlavaNextMultiModalProjector(
#     (linear_1): Linear(in_features=1024, out_features=4096, bias=True)
#     (act): GELUActivation()
#     (linear_2): Linear(in_features=4096, out_features=4096, bias=True)
#   )
#   (language_model): LlamaForCausalLM(
#     (model): LlamaModel(
#       (embed_tokens): Embedding(32064, 4096, padding_idx=0)
#       (layers): ModuleList(
#         (0-31): 32 x LlamaDecoderLayer(
#           (self_attn): LlamaAttention(
#             (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
#             (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
#             (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
#             (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
#           )
#           (mlp): LlamaMLP(
#             (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
#             (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
#             (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
#             (act_fn): SiLU()
#           )
#           (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
#           (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
#         )
#       )
#       (norm): LlamaRMSNorm((4096,), eps=1e-05)
#       (rotary_emb): LlamaRotaryEmbedding()
#     )
#     (lm_head): Linear(in_features=4096, out_features=32064, bias=False)
#   )
# )
# vision_tower: CLIPVisionModel(
#   (vision_model): CLIPVisionTransformer(
#     (embeddings): CLIPVisionEmbeddings(
#       (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
#       (position_embedding): Embedding(577, 1024)
#     )
#     (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#     (encoder): CLIPEncoder(
#       (layers): ModuleList(
#         (0-23): 24 x CLIPEncoderLayer(
#           (self_attn): CLIPSdpaAttention(
#             (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
#             (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
#             (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
#             (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
#           )
#           (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#           (mlp): CLIPMLP(
#             (activation_fn): QuickGELUActivation()
#             (fc1): Linear(in_features=1024, out_features=4096, bias=True)
#             (fc2): Linear(in_features=4096, out_features=1024, bias=True)
#           )
#           (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#         )
#       )
#     )
#     (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#   )
# )
# vision_tower.vision_model: CLIPVisionTransformer(
#   (embeddings): CLIPVisionEmbeddings(
#     (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
#     (position_embedding): Embedding(577, 1024)
#   )
#   (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#   (encoder): CLIPEncoder(
#     (layers): ModuleList(
#       (0-23): 24 x CLIPEncoderLayer(
#         (self_attn): CLIPSdpaAttention(
#           (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
#           (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
#           (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
#           (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
#         )
#         (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#         (mlp): CLIPMLP(
#           (activation_fn): QuickGELUActivation()
#           (fc1): Linear(in_features=1024, out_features=4096, bias=True)
#           (fc2): Linear(in_features=4096, out_features=1024, bias=True)
#         )
#         (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#       )
#     )
#   )
#   (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
# )
# vision_tower.vision_model.embeddings: CLIPVisionEmbeddings(
#   (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
#   (position_embedding): Embedding(577, 1024)
# )
# vision_tower.vision_model.embeddings.patch_embedding: Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
# vision_tower.vision_model.embeddings.position_embedding: Embedding(577, 1024)
# vision_tower.vision_model.pre_layrnorm: LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
# vision_tower.vision_model.encoder: CLIPEncoder(
#   (layers): ModuleList(
#     (0-23): 24 x CLIPEncoderLayer(
#       (self_attn): CLIPSdpaAttention(
#         (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
#         (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
#         (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
#         (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
#       )
#       (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#       (mlp): CLIPMLP(
#         (activation_fn): QuickGELUActivation()
#         (fc1): Linear(in_features=1024, out_features=4096, bias=True)
#         (fc2): Linear(in_features=4096, out_features=1024, bias=True)
#       )
#       (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#     )
#   )
# )
# vision_tower.vision_model.encoder.layers: ModuleList(
#   (0-23): 24 x CLIPEncoderLayer(
#     (self_attn): CLIPSdpaAttention(
#       (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
#       (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
#       (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
#       (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
#     )
#     (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#     (mlp): CLIPMLP(
#       (activation_fn): QuickGELUActivation()
#       (fc1): Linear(in_features=1024, out_features=4096, bias=True)
#       (fc2): Linear(in_features=4096, out_features=1024, bias=True)
#     )
#     (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#   )
# )
# vision_tower.vision_model.encoder.layers.0: CLIPEncoderLayer(
#   (self_attn): CLIPSdpaAttention(
#     (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
#     (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
#     (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
#     (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
#   )
#   (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#   (mlp): CLIPMLP(
#     (activation_fn): QuickGELUActivation()
#     (fc1): Linear(in_features=1024, out_features=4096, bias=True)
#     (fc2): Linear(in_features=4096, out_features=1024, bias=True)
#   )
#   (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
# )
# ***


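# ... (excerpt: entries between vision_tower.vision_model.encoder.layers.0 and
#      language_model.model.layers.30 omitted) ...
#   )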
#   (mlp): LlamaMLP(
#     (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
#     (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
#     (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
#     (act_fn): SiLU()
#   )
#   (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
#   (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
# )
# language_model.model.layers.30.self_attn: LlamaAttention(
#   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
#   (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
#   (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
#   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
# )
# language_model.model.layers.30.self_attn.q_proj: Linear(in_features=4096, out_features=4096, bias=False)
# language_model.model.layers.30.self_attn.k_proj: Linear(in_features=4096, out_features=4096, bias=False)
# language_model.model.layers.30.self_attn.v_proj: Linear(in_features=4096, out_features=4096, bias=False)
# language_model.model.layers.30.self_attn.o_proj: Linear(in_features=4096, out_features=4096, bias=False)
# language_model.model.layers.30.mlp: LlamaMLP(
#   (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
#   (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
#   (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
#   (act_fn): SiLU()
# )
# language_model.model.layers.30.mlp.gate_proj: Linear(in_features=4096, out_features=11008, bias=False)
# language_model.model.layers.30.mlp.up_proj: Linear(in_features=4096, out_features=11008, bias=False)
# language_model.model.layers.30.mlp.down_proj: Linear(in_features=11008, out_features=4096, bias=False)
# language_model.model.layers.30.mlp.act_fn: SiLU()
# language_model.model.layers.30.input_layernorm: LlamaRMSNorm((4096,), eps=1e-05)
# language_model.model.layers.30.post_attention_layernorm: LlamaRMSNorm((4096,), eps=1e-05)
# language_model.model.layers.31: LlamaDecoderLayer(
#   (self_attn): LlamaAttention(
#     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
#     (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
#     (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
#     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
#   )
#   (mlp): LlamaMLP(
#     (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
#     (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
#     (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
#     (act_fn): SiLU()
#   )
#   (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
#   (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
# )

# language_model.model.layers.31: LlamaDecoderLayer(
#   (self_attn): LlamaAttention(
#     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
#     (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
#     (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
#     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
#   )
#   (mlp): LlamaMLP(
#     (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
#     (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
#     (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
#     (act_fn): SiLU()
#   )
#   (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
#   (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
# )
# language_model.model.layers.31.self_attn: LlamaAttention(
#   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
#   (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
#   (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
#   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
# )
# language_model.model.layers.31.self_attn.q_proj: Linear(in_features=4096, out_features=4096, bias=False)
# language_model.model.layers.31.self_attn.k_proj: Linear(in_features=4096, out_features=4096, bias=False)
# language_model.model.layers.31.self_attn.v_proj: Linear(in_features=4096, out_features=4096, bias=False)
# language_model.model.layers.31.self_attn.o_proj: Linear(in_features=4096, out_features=4096, bias=False)
# language_model.model.layers.31.mlp: LlamaMLP(
#   (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
#   (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
#   (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
#   (act_fn): SiLU()
# )
# language_model.model.layers.31.mlp.gate_proj: Linear(in_features=4096, out_features=11008, bias=False)
# language_model.model.layers.31.mlp.up_proj: Linear(in_features=4096, out_features=11008, bias=False)
# language_model.model.layers.31.mlp.down_proj: Linear(in_features=11008, out_features=4096, bias=False)
# language_model.model.layers.31.mlp.act_fn: SiLU()
# language_model.model.layers.31.input_layernorm: LlamaRMSNorm((4096,), eps=1e-05)
# language_model.model.layers.31.post_attention_layernorm: LlamaRMSNorm((4096,), eps=1e-05)
# language_model.model.norm: LlamaRMSNorm((4096,), eps=1e-05)
# language_model.model.rotary_emb: LlamaRotaryEmbedding()
# language_model.lm_head: Linear(in_features=4096, out_features=32064, bias=False)


from transformers import AutoTokenizer
import copy

import gc
from transformers import AutoProcessor, LlavaNextForConditionalGeneration
from PIL import Image
import json
# Hugging Face checkpoint shared by the processor and the model.
_MODEL_ID = "llava-hf/llava-v1.6-vicuna-7b-hf"
processor = AutoProcessor.from_pretrained(_MODEL_ID)

model = LlavaNextForConditionalGeneration.from_pretrained(_MODEL_ID)
model.to("cuda:0")

# Print the first ten entries of named_modules() (the root model comes first,
# under the empty name "") to inspect the module hierarchy.
n = 0
for n, (name, module) in enumerate(model.named_modules(), start=1):
    print(f"{name}: {module}")
    if n == 10:
        break


# print('test.txt')
# tokens = [processor.tokenizer.decode([t]) for t in [29871, 13, 29888, 331, 744]]
# print(tokens)
# print("3")

residual_storage = dict()  # holds the most recent tensor captured by hook_fn


def hook_fn(layer, input, output):
    """Capture the first positional input of a module call.

    The signature matches the ``(module, input, output)`` callback expected by
    ``nn.Module.register_forward_hook``; nothing in the visible code registers it.
    The detached tensor (same storage, no autograd history) is stored under
    ``residual_storage["residual"]``, replacing any earlier capture. ``layer``
    and ``output`` are ignored.
    """
    first_arg = input[0]
    residual_storage["residual"] = first_arg.detach()


def load_image(path, size=224):
    """Open an image file and return it as RGB, resized to ``size`` x ``size``.

    The aspect ratio is not preserved. The file is opened in a ``with`` block
    so its handle is released as soon as the RGB copy exists. Without it, Pillow
    keeps the file open for multi-frame formats.
    """
    with Image.open(path) as image:
        rgb = image.convert("RGB")  # independent, fully loaded copy
    return rgb.resize((size, size))
def compare_images(image_a_path, image_b_path):
    """Load two images and paste them side by side onto one RGB canvas.

    Python garbage and the CUDA cache are flushed first. Both images go through
    ``load_image``. If their heights differ, the right-hand image is rescaled
    to the left one's height, keeping its aspect ratio.
    """
    gc.collect()
    torch.cuda.empty_cache()

    left = load_image(image_a_path)
    right = load_image(image_b_path)

    target_height = left.height
    if right.height != target_height:
        scaled_width = int(right.width * target_height / right.height)
        right = right.resize((scaled_width, target_height))

    canvas = Image.new("RGB", (left.width + right.width, target_height))
    for image, x_offset in ((left, 0), (right, left.width)):
        canvas.paste(image, (x_offset, 0))
    return canvas


def inference(dataset, unsafe_prefixes=("M", "m", "F", "f", "w", "W")):
    """Classify samples by their greedy next token and average hidden states per class.

    Assumes (TODO confirm against the dataset files) that every sample is a
    chat-template message list ``[system, user]``. The user ``content`` is
    expected to hold zero, one or two leading image items (file paths) followed
    by text. The system text is overwritten with the module-level
    ``base_prompt``. Image paths are replaced by loaded PIL images; two images
    are merged side by side via ``compare_images``. NOTE: ``dataset`` is
    mutated in place.

    A sample counts as "unsafe" when its stripped, decoded next token starts
    with one of ``unsafe_prefixes``; otherwise it counts as "safe". The
    last-position entry of ``outputs.hidden_states[-1]`` is summed into the
    matching class.

    Args:
        dataset: iterable of samples as described above.
        unsafe_prefixes: prefixes marking a next token as "unsafe".

    Returns:
        ``(result, safe_v_mean, unsafe_v_mean)``:
        - ``result`` maps each decoded next token to how often it was predicted.
        - Both means have shape ``[1, 1, hidden_size]``.
        - A class with no samples yields a zero vector and a printed warning,
          instead of the NaNs of a 0/0 division.
    """
    prefixes = tuple(unsafe_prefixes)  # str.startswith needs a tuple
    hidden_size = model.config.text_config.hidden_size
    safe_v = torch.zeros([1, 1, hidden_size], device=model.device)
    unsafe_v = torch.zeros([1, 1, hidden_size], device=model.device)
    safe_count = 0
    unsafe_count = 0
    result = {}
    for data in dataset:
        data[0]['content'][0]['text'] = base_prompt
        user_content = data[1]['content']
        if len(user_content) == 2:
            user_content[0]['image'] = load_image(user_content[0]['image'])
        elif len(user_content) == 3:
            user_content[0]['image'] = compare_images(
                user_content[0]['image'], user_content[1]['image'])
            user_content.pop(1)  # second image now lives inside the merged one

        inputs = processor.apply_chat_template(
            data,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt").to(model.device, torch.float16)

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True, return_dict=True)

        print("***")

        # outputs.logits is lm_head applied to outputs.hidden_states[-1], so a
        # single argmax serves both the tally and the classification.
        # Re-running lm_head over the whole sequence (outside no_grad) would
        # only repeat that work and allocate a [B, T, vocab] tensor plus graph.
        next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)
        generation = processor.tokenizer.decode(next_token.tolist())
        result[generation] = result.get(generation, 0) + 1

        last_hidden = outputs.hidden_states[-1][:, -1, :]  # [B, hidden_size]
        if generation.strip().startswith(prefixes):
            unsafe_v += last_hidden
            unsafe_count += 1
        else:
            safe_v += last_hidden
            safe_count += 1

        del outputs, inputs  # drop references so empty_cache can release them
        torch.cuda.empty_cache()
        gc.collect()

    if safe_count == 0 or unsafe_count == 0:
        print(f"warning: empty class (safe_count={safe_count}, "
              f"unsafe_count={unsafe_count}); its mean is returned as zeros")
    # An empty class has an all-zero sum, so dividing by max(count, 1) gives zeros.
    safe_v_mean = safe_v / max(safe_count, 1)
    unsafe_v_mean = unsafe_v / max(unsafe_count, 1)
    return result, safe_v_mean, unsafe_v_mean





import torch

def projection_matrix(v: torch.Tensor) -> torch.Tensor:
    """Return the rank-1 orthogonal projector onto the direction of ``v``.

    ``v`` is squeezed first, so a (1, d) or (1, 1, d) tensor is treated as a
    length-d vector. The result is ``u u^T`` with ``u = v / ||v||``, shape (d, d).

    Raises:
        ValueError: if ``v`` is the zero vector.
    """
    v = v.squeeze()
    v_norm = v.norm()

    if v_norm == 0:
        raise ValueError("Zero vector cannot form a projection matrix.")

    v_unit = v / v_norm
    # torch.ger is deprecated; torch.outer computes the same outer product u u^T.
    return torch.outer(v_unit, v_unit)



def orthogonal_projection_matrix(a: torch.Tensor):
    """Return ``a a^T / (a^T a)``, the orthogonal projector onto span(a).

    Accepts a 1-D tensor of length d or a (1, d) row and returns a (d, d)
    matrix. Invalid input fails with ``ValueError``. Without these checks, a
    zero vector would silently produce NaNs and a (k, d) input with k > 1 would
    raise a cryptic torch error.

    Raises:
        ValueError: if ``a`` is neither 1-D nor (1, d), or is the zero vector.
    """
    if a.ndim == 2 and a.shape[0] == 1:
        a = a.squeeze(0)  # (1, d) -> (d,)
    elif a.ndim != 1:
        raise ValueError(f"Invalid shape for a: {a.shape}, expected 1D or (1, 4096)")

    a_norm_squared = torch.dot(a, a)  # a^T a, a scalar
    if a_norm_squared == 0:
        raise ValueError("Zero vector cannot form a projection matrix.")
    return torch.outer(a, a) / a_norm_squared  # (d, d)

import torch.nn.functional as F
def remove_B_component(X, A, B, remove_scale=0.2, add_scale=0.2):
    """Steer hidden states away from direction ``B`` and toward direction ``A``.

    Every position of ``X`` is projected onto the unit vectors of ``A`` and
    ``B``. ``remove_scale`` times the B-projection is subtracted, and
    ``add_scale`` times the A-projection is added.

    Args:
        X: hidden states, e.g. (1, seq_len, 4096).
        A, B: direction vectors whose last dim matches ``X``, e.g. (1, 1, 4096).
        remove_scale: weight of the B-projection that is subtracted.
        add_scale: weight of the A-projection that is added.

    Returns:
        ``(proj_A, orth_proj, result)``:
        - ``proj_A`` is the projection of ``X`` onto ``A``.
        - ``orth_proj`` is ``X - remove_scale * proj_B``.
        - ``result`` is ``orth_proj + add_scale * proj_A``.
    """
    A_normalized = F.normalize(A, dim=-1)
    B_normalized = F.normalize(B, dim=-1)

    # Per-position scalar projection coefficients, shape (..., seq_len, 1).
    proj_coeff_A = torch.matmul(X, A_normalized.transpose(-1, -2))
    # Each coefficient must use its own unit vector: projecting onto A here
    # would remove <X, a> * b instead of the actual component along b.
    proj_coeff_B = torch.matmul(X, B_normalized.transpose(-1, -2))
    proj_A = proj_coeff_A * A_normalized
    proj_B = proj_coeff_B * B_normalized

    orth_proj = X - remove_scale * proj_B
    result = orth_proj + add_scale * proj_A
    return proj_A, orth_proj, result


def _greedy_continue(input_ids, attention_mask, extra_inputs, max_new_tokens):
    """Greedily extend ``input_ids`` by up to ``max_new_tokens`` tokens.

    No KV cache is used: every step re-runs the full sequence. ``extra_inputs``
    carries the remaining processor outputs (e.g. ``pixel_values`` /
    ``image_sizes``) so the image stays visible at every step. Decoding stops
    early once the tokenizer's EOS token is produced.
    """
    eos_id = processor.tokenizer.eos_token_id
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask,
                           **extra_inputs).logits
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # [B, 1]
        input_ids = torch.cat([input_ids, next_token], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=1)
        if eos_id is not None and bool((next_token == eos_id).all()):
            break
    return input_ids


def evaluation(safe_mean, unsafe_mean, dataset):
    """Compare the plain next-token prediction with a steered one for each sample.

    Samples are preprocessed like in ``inference``, which mutates ``dataset``
    in place. After one forward pass:
    - The final hidden state at the last position is steered with
      ``remove_B_component(h, safe_mean, unsafe_mean)`` and decoded via
      ``lm_head``.
    - Plain and steered next tokens are tallied and printed at the end.
    - For the last sample, two continuations are also printed: one from
      ``model.generate`` and one seeded with the steered first token, then
      continued greedily without further steering.

    Returns:
        Always ``1`` (kept for backward compatibility); results are printed.
    """
    print("evaluation")
    max_new_tokens = 50
    baseline_counts = {}
    steered_counts = {}
    lm_head = model.language_model.lm_head
    last_index = len(dataset)
    for index, data in enumerate(dataset, start=1):
        data[0]['content'][0]['text'] = base_prompt
        user_content = data[1]['content']
        if len(user_content) == 2:
            user_content[0]['image'] = load_image(user_content[0]['image'])
        elif len(user_content) == 3:
            user_content[0]['image'] = compare_images(
                user_content[0]['image'], user_content[1]['image'])
            user_content.pop(1)  # second image now lives inside the merged one

        inputs = processor.apply_chat_template(
            data,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt").to(model.device, torch.float16)

        # Everything here is inference only. Without no_grad, calling lm_head
        # on the steered states would build an autograd graph over its weights.
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True, return_dict=True)
            print("***")

            baseline_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)
            baseline = processor.tokenizer.decode(baseline_token.tolist())
            baseline_counts[baseline] = baseline_counts.get(baseline, 0) + 1

            # In transformers' Llama implementation, hidden_states[-1] is
            # already the output of the final norm. It is exactly what lm_head
            # consumes for outputs.logits, and what inference() averaged into
            # safe_mean/unsafe_mean, so it must not be normalized again.
            # Only the last position is needed for the next token.
            last_hidden = outputs.hidden_states[-1][:, -1:, :]  # [B, 1, hidden]
            _, _, steered_hidden = remove_B_component(last_hidden, safe_mean, unsafe_mean)
            steered_logits = lm_head(steered_hidden.to(lm_head.weight.dtype))  # [B, 1, vocab]
            steered_token = torch.argmax(steered_logits[:, -1, :], dim=-1, keepdim=True)  # [B, 1]
            steered = processor.tokenizer.decode(steered_token[:, 0].tolist())
            steered_counts[steered] = steered_counts.get(steered, 0) + 1

            if index == last_index:
                generated = model.generate(**inputs, max_new_tokens=100)
                print("baseline generation:",
                      processor.tokenizer.decode(generated[0], skip_special_tokens=True))
                del generated

                # Keep pixel_values / image_sizes so the continuation still sees the image.
                extra_inputs = {k: v for k, v in inputs.items()
                                if k not in ("input_ids", "attention_mask")}
                steered_ids = _greedy_continue(
                    torch.cat([inputs.input_ids, steered_token], dim=1),
                    torch.cat([inputs.attention_mask, torch.ones_like(steered_token)], dim=1),
                    extra_inputs,
                    max_new_tokens,
                )
                print("steered generation:",
                      processor.tokenizer.decode(steered_ids[0], skip_special_tokens=True))

        del outputs, inputs  # drop references so empty_cache can release them
        torch.cuda.empty_cache()
        gc.collect()

    print("baseline next tokens:", baseline_counts)
    print("steered next tokens:", steered_counts)
    return 1


# Script entry point: estimate class means with `inference`, then run `evaluation`.
if __name__ == "__main__":
    # Prompt template used as the system message of every sample; read by
    # `inference` / `evaluation` through the module-level name `base_prompt`.
    with open("prompts/base/base_gender_VL_MC_word.txt", "r", encoding="utf-8") as f:
        base_prompt = f.read()

    # NOTE(review): '/data' is a placeholder; point it at the JSON dataset file.
    data_path1 = '/data'
    with open(data_path1, 'r', encoding='utf-8') as f:
        dataset1 = json.load(f)

    # `inference` mutates its samples in place (prompt text, image paths -> PIL
    # images), so it gets a deep copy and `evaluation` receives untouched data.
    result, safe_mean, unsafe_mean = inference(copy.deepcopy(dataset1))
    print(result)

    evaluation(safe_mean, unsafe_mean, dataset1)

    torch.cuda.empty_cache()
