from transformers import AutoProcessor, Gemma3ForConditionalGeneration

from PIL import Image
import requests
import copy
import torch
from tqdm import tqdm

import sys
import warnings
import json
import os
import time


import argparse

import base64
import io



# Command-line interface of the script. Every option except --image_dir is
# mandatory; --image_dir defaults to the empty string.
parser = argparse.ArgumentParser(description="Gemma")
parser.add_argument(
    "--read_dir",
    type=str,
    default=None,
    required=True,
    help="Directory containing the input JSON files.",
)
parser.add_argument(
    "--image_dir",
    type=str,
    default="",
    help="Image directory (if applicable). None if images are embedded in JSON.",
)
parser.add_argument(
    "--model_size",
    type=str,
    default=None,
    required=True,
    help="Size of gemma",
)
parser.add_argument(
    "--data_mode",
    type=str,
    default=None,
    required=True,
    choices=["train", "val"],
    help='Data mode ("train" or "val").',
)
parser.add_argument(
    "--vector_save_dir",
    type=str,
    required=True,
    help="Directory to save the output statistics files.",
)

args = parser.parse_args()

# Unpack the parsed options into the module-level names used by the rest of
# the script.
read_dir, image_base_dir = args.read_dir, args.image_dir
model_size, data_mode = args.model_size, args.data_mode
vector_save_dir = args.vector_save_dir

# Checkpoint name derived from the --model_size option.
model_name = f"google/gemma-3-{model_size}b-it"

# Load the checkpoint and switch it to evaluation mode.
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_name,
    device_map="auto",
).eval()

# The processor is used later to turn a chat message (image + question) into
# model inputs; its tokenizer is exposed separately for convenience.
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = processor.tokenizer

# --read_dir is documented as a directory, so build the path with
# os.path.join: plain string concatenation silently produced a wrong path
# (e.g. "datatrain.json") whenever the directory had no trailing separator.
input_json_path = os.path.join(read_dir, data_mode + ".json")



# Silence all Python warnings from this point on.
warnings.filterwarnings("ignore")

# NOTE(review): these two names are not referenced anywhere else in the
# visible script; kept in case other code relies on them.
device = "cuda"
device_map = "auto"

# Banner showing which checkpoint is being processed.
print('------------------')
print(model_name)
print('------------------')

# Read the sample list. The encoding is given explicitly because JSON text is
# UTF-8, while open() would otherwise fall back to the platform's locale
# encoding and could fail on (or garble) non-ASCII questions.
with open(input_json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)


# Per-sample embedding slices (one CPU tensor per sample), concatenated after
# the loop.
text_embedding_list = []
image_embedding_list = []

# Number of samples processed so far.
count = 0

# Token ids used to locate the slices inside each tokenized prompt.
# NOTE(review): these values are hard-coded for the tokenizer of the loaded
# checkpoint. Presumably 255999 / 256000 are the begin-/end-of-image markers;
# confirm all four ids against the tokenizer before reusing this script with
# a different model.
TEXT_START_TOKEN_ID = 14977
TEXT_END_TOKEN_ID = 236761
IMAGE_START_TOKEN_ID = 255999
IMAGE_END_TOKEN_ID = 256000


def _load_image(item):
    """Return the sample's image as an RGB PIL image.

    ``item['image']`` is read as a path relative to ``image_base_dir`` when
    that directory is set, and as a base64-encoded image payload otherwise.
    Both branches convert to RGB so the processor always receives the same
    image mode.
    """
    if image_base_dir:
        source = os.path.join(image_base_dir, item['image'])
    else:
        source = io.BytesIO(base64.b64decode(item['image']))
    # convert() returns a new image, so the handle opened by Image.open can
    # be released as soon as the conversion is done.
    with Image.open(source) as img:
        return img.convert("RGB")


def _first_human_question(item):
    """Return the first 'human' turn with '\\n<image>' removed, or None."""
    for conv in item['conversations']:
        if conv['from'] == 'human':
            return conv['value'].replace('\n<image>', '')
    return None


def _unique_position(input_ids, token_id):
    """Return the index of the single occurrence of ``token_id``.

    Raises:
        ValueError: if the token occurs zero times or more than once.
    """
    positions = (input_ids == token_id).nonzero(as_tuple=True)[0]
    if positions.numel() != 1:
        raise ValueError(
            f"Expected exactly one occurrence of token id {token_id}, "
            f"found {positions.numel()}."
        )
    return positions.item()


def _last_position(input_ids, token_id):
    """Return the index of the last occurrence of ``token_id``.

    Raises:
        ValueError: if the token does not occur at all.
    """
    positions = (input_ids == token_id).nonzero(as_tuple=True)[0]
    if positions.numel() == 0:
        raise ValueError(f"Token id {token_id} not found in the prompt.")
    return positions[-1].item()


for sample_index, item in enumerate(tqdm(data)):

    image_ori = _load_image(item)

    # Fail loudly instead of feeding ``None`` into the chat template.
    question = _first_human_question(item)
    if question is None:
        raise ValueError(
            f"Sample {sample_index} has no 'human' turn in its conversations."
        )

    message_dual = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a helpful assistant."}]
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_ori},
                {"type": "text", "text": question}
            ]
        }
    ]

    inputs = processor.apply_chat_template(
        message_dual,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device, dtype=torch.bfloat16)

    # Only a lookup is needed, so skip autograd bookkeeping.
    with torch.no_grad():
        inputs_embeds_full = model.model.get_input_embeddings()(inputs['input_ids'])

    token_ids = inputs['input_ids'][0]

    text_ind_1 = _unique_position(token_ids, TEXT_START_TOKEN_ID)
    text_ind_2 = _last_position(token_ids, TEXT_END_TOKEN_ID)

    image_ind_1 = _unique_position(token_ids, IMAGE_START_TOKEN_ID)
    image_ind_2 = _unique_position(token_ids, IMAGE_END_TOKEN_ID)

    # Text slice includes both delimiter tokens; image slice excludes them.
    text_feature = inputs_embeds_full[0, text_ind_1:text_ind_2+1, :]
    # NOTE(review): this slice is taken from the input-embedding lookup of
    # input_ids, i.e. it holds the embeddings of whatever token ids sit
    # between the two image markers, not the output of a vision encoder.
    # Confirm that this is the intended "image" representation.
    image_feature = inputs_embeds_full[0, image_ind_1+1:image_ind_2, :]

    text_embedding_list.append(text_feature.detach().cpu())
    image_embedding_list.append(image_feature.detach().cpu())

    count += 1



# Stack the per-sample slices along the first axis so that every row is one
# embedding vector.
text_all_embedding = torch.cat(text_embedding_list, 0)
image_all_embedding = torch.cat(image_embedding_list, 0)

print('text_all_embedding shape:', text_all_embedding.shape)
print('image_all_embedding shape:', image_all_embedding.shape)
print('Current model and data mode:', model_name, data_mode)

print("\n" + "="*50)
print("--- Detailed Per-Feature-Dimension Statistics ---")

# Text: mean and population standard deviation (unbiased=False) taken over
# the first axis, followed by a short summary of both vectors.
text_mean_per_dim = torch.mean(text_all_embedding, dim=0)
text_std_per_dim = torch.std(text_all_embedding, dim=0, unbiased=False)
print(
    f"\nText Mean Vector (Shape: {text_mean_per_dim.shape})"
)
print(
    f"Text STD Vector (Shape:  {text_std_per_dim.shape})"
)
print(
    f"  - Mean of Text Mean Vector: {text_mean_per_dim.mean().item():.6f}"
)
print(
    f"  - Mean of Text STD Vector:  {text_std_per_dim.mean().item():.6f}"
)

# Image: same statistics for the image slices.
image_mean_per_dim = torch.mean(image_all_embedding, dim=0)
image_std_per_dim = torch.std(image_all_embedding, dim=0, unbiased=False)
print(
    f"\nImage Mean Vector (Shape: {image_mean_per_dim.shape})"
)
print(
    f"Image STD Vector (Shape:  {image_std_per_dim.shape})"
)
print(
    f"  - Mean of Image Mean Vector: {image_mean_per_dim.mean().item():.6f}"
)
print(
    f"  - Mean of Image STD Vector:  {image_std_per_dim.mean().item():.6f}"
)


print("\n" + "="*50)

save_dir = vector_save_dir
# Create the output directory if needed so the run does not fail at the very
# last step; an empty path means "current directory" and needs no creation.
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
print(f"Saving statistics to directory: {save_dir}")

# The checkpoint name contains '/', which must not end up in a file name.
clean_model_name = model_name.replace('/', '_')

# Text statistics: a dict with the per-dimension 'mean' and 'std' vectors.
text_stats_path = os.path.join(save_dir, f"{clean_model_name}_{data_mode}_text_stats.pt")
torch.save({
    'mean': text_mean_per_dim,
    'std': text_std_per_dim
}, text_stats_path)
print(f"-> Saved text statistics to: {text_stats_path}")

# Image statistics: same layout as the text file.
image_stats_path = os.path.join(save_dir, f"{clean_model_name}_{data_mode}_image_stats.pt")
torch.save({
    'mean': image_mean_per_dim,
    'std': image_std_per_dim
}, image_stats_path)
print(f"-> Saved image statistics to: {image_stats_path}")
print("="*50)