

from transformers import AutoTokenizer
import copy

import gc
from transformers import AutoProcessor, LlavaForConditionalGeneration, LlavaNextForConditionalGeneration
from PIL import Image
import json
# NOTE: the system prompt is read from disk in the `__main__` block further down.

import torch



#
# model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
# processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")


# model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-13b-hf",device_map="cuda", torch_dtype=torch.float16)
# processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-13b-hf")



# Checkpoint shared by the model weights and the matching processor.
MODEL_ID = "llava-hf/llava-v1.6-vicuna-7b-hf"
# Device the model is moved to once it has been loaded.
TARGET_DEVICE = "cuda:0"

model = LlavaNextForConditionalGeneration.from_pretrained(MODEL_ID)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Larger alternative, kept for reference:
# model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf",device_map="cuda", torch_dtype=torch.float16)
# processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")

model.to(TARGET_DEVICE)

# Number of decoder layers in the language model.
total_layers = len(model.language_model.model.layers)

print("start")


def load_image(path, size=224):
    """Open an image file and return it as a square RGB image.

    Args:
        path: Location of the image, handed directly to ``Image.open``.
        size: Edge length in pixels; the image is resized to
            ``(size, size)``, so the aspect ratio is not preserved.

    Returns:
        The RGB-converted image resized to ``size`` x ``size``.
    """
    target_dims = (size, size)
    rgb_image = Image.open(path).convert("RGB")
    resized = rgb_image.resize(target_dims)
    return resized

def compare_images(image_a_path, image_b_path):
    """Paste two images next to each other on a single RGB canvas.

    Despite the name, nothing is compared: both files are loaded with
    ``load_image`` and concatenated horizontally, first image on the left.

    Args:
        image_a_path: Path of the image placed on the left.
        image_b_path: Path of the image placed on the right.

    Returns:
        A new RGB image whose width is the sum of both widths and whose
        height is the height of the left image.
    """
    # Release Python garbage and cached CUDA memory before building images.
    gc.collect()
    torch.cuda.empty_cache()

    left = load_image(image_a_path)
    right = load_image(image_b_path)

    # If the heights differ, rescale the right image to the left image's
    # height while keeping its aspect ratio.
    target_height = left.height
    if right.height != target_height:
        scaled_width = int(right.width * target_height / right.height)
        right = right.resize((scaled_width, target_height))

    canvas = Image.new("RGB", (left.width + right.width, target_height))
    for tile, x_offset in ((left, 0), (right, left.width)):
        canvas.paste(tile, (x_offset, 0))
    return canvas






# Answer words that count as a "neutral" prediction in the per-layer readout.
_NEUTRAL_ANSWERS = ('both', 'similar', 'equal')
# How many trailing entries of ``outputs.hidden_states`` are decoded.
_NUM_PROBED_STATES = 10


def _is_neutral_token(token_text):
    """Return True if the decoded token is a non-empty piece of a neutral answer.

    The match is a substring test (token inside the answer word) so that
    sub-word tokens such as a prefix of "similar" are counted as well.
    """
    token_text = token_text.strip().lower()
    return bool(token_text) and any(
        token_text in answer for answer in _NEUTRAL_ANSWERS
    )


def _prepare_messages(sample, system_prompt):
    """Return a chat-ready copy of ``sample`` with images loaded.

    Assumes ``sample[0]`` is the system message and ``sample[1]`` the user
    message, with image entries preceding the text entry (inferred from the
    indexing below -- confirm against the dataset schema).

    The sample is deep-copied first so the caller's dataset keeps its image
    paths and can safely be passed to ``inference`` again.
    """
    messages = copy.deepcopy(sample)
    messages[0]['content'][0]['text'] = system_prompt

    user_content = messages[1]['content']
    if len(user_content) == 2:
        # One image followed by the question text.
        user_content[0]['image'] = load_image(user_content[0]['image'])
    elif len(user_content) == 3:
        # Two images: merge them side by side into a single image entry.
        user_content[0]['image'] = compare_images(
            user_content[0]['image'], user_content[1]['image']
        )
        user_content.pop(1)
    # Any other length (e.g. text only) is passed through unchanged.
    return messages


def inference(dataset, system_prompt=None):
    """Decode the next token from the last hidden states of every sample.

    For each sample a single forward pass is run with
    ``output_hidden_states=True``. Each of the last ``_NUM_PROBED_STATES``
    entries of ``outputs.hidden_states`` is projected through the language
    model head at the final sequence position, and the arg-max token is
    checked against the neutral answer words. Finally the number of samples
    with a neutral token is printed as a list with one count per probed
    position (an empty list for an empty dataset).

    Relies on the module-level ``model`` and ``processor``.

    Args:
        dataset: Iterable of chat samples (lists of message dicts). It is
            not modified.
        system_prompt: Text written into the system message of every
            sample. Defaults to the module-level ``base_prompt`` that the
            script's ``__main__`` block defines.

    Returns:
        Always ``1`` (kept for backward compatibility).

    Raises:
        NameError: If ``system_prompt`` is omitted and no module-level
            ``base_prompt`` exists.
    """
    if system_prompt is None:
        system_prompt = base_prompt

    text_model = model.language_model
    per_sample_flags = []

    for sample in dataset:
        messages = _prepare_messages(sample, system_prompt)

        # Generation-only options such as ``max_new_tokens`` are not passed
        # here; nothing is generated. Floating-point inputs are cast to the
        # model's own dtype instead of a hard-coded float16, so they match
        # however the checkpoint was loaded.
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device, model.dtype)

        layer_flags = []
        # Keep the whole readout under no_grad so that neither the forward
        # pass nor the extra norm / lm_head calls build an autograd graph.
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            hidden_states = outputs.hidden_states
            last_index = len(hidden_states) - 1
            # Clamp so models with fewer states never wrap to negative indices.
            first_index = max(0, len(hidden_states) - _NUM_PROBED_STATES)

            for index in range(first_index, last_index + 1):
                # Only the final position is read, so slice before the norm
                # and the vocabulary projection (assumes the norm acts per
                # token, as RMSNorm does -- confirm for other architectures).
                last_token_state = hidden_states[index][:, -1:, :]
                if index != last_index:
                    # Intermediate states get the final norm first; the last
                    # entry is used as-is, presumably because it is already
                    # normalised by the model -- TODO confirm.
                    last_token_state = text_model.model.norm(last_token_state)
                logits = text_model.lm_head(last_token_state)

                token_ids = torch.argmax(logits[:, -1, :], dim=-1)
                decoded = processor.tokenizer.decode(token_ids.tolist())
                layer_flags.append(_is_neutral_token(decoded))

        per_sample_flags.append(layer_flags)

    # Column-wise count: for each probed position, how many samples produced
    # a neutral token. ``zip(*[])`` is empty, so no samples yields ``[]``.
    print([sum(column) for column in zip(*per_sample_flags)])

    torch.cuda.empty_cache()
    gc.collect()

    return 1













if __name__ == "__main__":
    # This block is deliberately not wrapped in a function: ``inference``
    # reads ``base_prompt`` as a module-level global.
    with open("prompts/base/base_gender_VL_MC_word.txt", "r", encoding="utf-8") as f:
        base_prompt = f.read()

    # NOTE(review): these look like placeholder paths -- point them at the
    # JSON dataset files before running.
    data_path1 = '/data'
    data_path2 = 'data'

    with open(data_path1, 'r', encoding='utf-8') as f:
        dataset1 = json.load(f)
    with open(data_path2, 'r', encoding='utf-8') as f:
        dataset2 = json.load(f)

    # Work on copies so the loaded datasets keep their original image paths.
    dataset1_copy = copy.deepcopy(dataset1)
    dataset2_copy = copy.deepcopy(dataset2)
    dataset = dataset1_copy

    # Pass the copy; the original script passed ``dataset1`` here, which
    # defeated the defensive copy above.
    inference(dataset)

