from transformers import AutoProcessor, Gemma3ForConditionalGeneration

from PIL import Image
import requests
import copy
import torch
from tqdm import tqdm

import sys
import warnings
import json
import os
import time
import argparse



import base64
import io


# Single source of truth for the RNG seed, so the value that is applied and the
# value that is logged can never drift apart.
SEED = 42

# Seed the default (CPU) generator and every CUDA generator. The noise drawn
# later with torch.randn_like depends on this for run-to-run reproducibility.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

print(f"Random seed set to {SEED} for reproducibility.")


##### Check add or replace carefully!!!!


# Command-line interface. All path-like options are plain strings; the script
# builds the concrete file names from them further down.
parser = argparse.ArgumentParser(description="Run Gemma inference with specified parameters.")
parser.add_argument(
    '--read_dir',
    type=str,
    required=True,
    help='Directory containing the input JSON files.',
)
parser.add_argument(
    '--image_dir',
    type=str,
    default='',
    # The script tests this value for truthiness, so the literal string "None"
    # would be treated as a real directory name; only the empty string selects
    # the embedded-image path.
    help='Image directory (if applicable). Leave empty (the default) if images are embedded in the JSON as base64.',
)
parser.add_argument(
    '--save_dir',
    type=str,
    required=True,
    help='Directory to save the output JSON files.',
)
parser.add_argument(
    '--num_choices',
    type=int,
    default=4,
    help='Number of choices for the model to predict (default: 4). '
         'Any value other than 4 is treated as 2 (choices A/B only).',
)
parser.add_argument(
    '--model_size',
    type=str,
    required=True,
    help='Size of gemma, inserted into the model id as "google/gemma-3-<model_size>b-it" (e.g. "4").',
)
parser.add_argument(
    '--data_mode',
    type=str,
    required=True,
    choices=["train", "val"],
    help='Data mode ("train" or "val").',
)
parser.add_argument(
    '--devi_vector_dir',
    type=str,
    required=True,
    help='Directory containing the pre-calculated deviation vectors.',
)


args = parser.parse_args()


# Unpack the parsed options into the module-level names used by the rest of the script.
read_dir = args.read_dir
image_base_dir = args.image_dir
save_dir = args.save_dir
num_choices = args.num_choices
model_size = args.model_size
data_mode = args.data_mode
devi_vector_dir = args.devi_vector_dir



# Hugging Face model id and the output file name, both derived from --model_size.
# save_json_name keeps its leading "/" because it is appended to a directory path later.
model_name = f"google/gemma-3-{model_size}b-it"
save_json_name = f"/gemma3_{model_size}b.json"

# Echo the resolved run configuration, one setting per line.
print(
    f"Using model: {model_name}",
    f"Data mode: {data_mode}",
    f"Save JSON name: {save_json_name}",
    sep="\n",
)



# Multiplier applied to the per-dimension standard deviation when sampling the
# replacement (noise) embeddings in the main loop.
scale_factor = 1.0

# Load the model with automatic device placement, then switch it to eval mode.
model = Gemma3ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
model.eval()

# Default processor for this checkpoint; its tokenizer is also used on its own
# to look up choice-letter token ids and to tokenize the bare question.
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = processor.tokenizer





# Input file: <read_dir>/<data_mode>.json.
# os.path.join supplies the separator, so --read_dir works with or without a
# trailing slash (plain string concatenation silently produced
# "<read_dir><data_mode>.json" when the slash was missing).
# NOTE (original author): the input is expected to hold 2000 samples.
input_json_path = os.path.join(read_dir, data_mode + ".json")

# Output file: <save_dir>/<data_mode>/gemma3_<size>b.json.
# save_json_name starts with "/", which os.path.join would treat as an absolute
# path and discard the preceding components, hence the lstrip.
# NOTE (original author): check that this path is the one meant for the 2000-sample run.
output_json_path = os.path.join(save_dir, data_mode, save_json_name.lstrip("/"))

# Create the output directory up front so a missing directory cannot make the
# final write fail after the whole inference run has completed.
os.makedirs(os.path.dirname(output_json_path), exist_ok=True)


warnings.filterwarnings("ignore")


# Device used for the auxiliary statistics tensors.
device = model.device
device_map = "auto"




print("\n" + "="*50)
print("Loading pre-calculated embedding statistics...")


# The statistics files are named after the model id with "/" replaced by "_".
stats_dir = devi_vector_dir
clean_model_name = model_name.replace('/', '_')

text_stats_path = os.path.join(stats_dir, f"{clean_model_name}_train_text_stats.pt")
image_stats_path = os.path.join(stats_dir, f"{clean_model_name}_train_image_stats.pt")


try:
    # map_location="cpu": load independently of the device the tensors were
    # saved from (that device may not exist on this machine); they are moved
    # to `device` explicitly below.
    # NOTE(security): torch.load unpickles the file - only load statistics
    # files from a trusted source.
    text_stats = torch.load(text_stats_path, map_location="cpu")
    image_stats = torch.load(image_stats_path, map_location="cpu")
except FileNotFoundError:
    print(f"ERROR: Could not find statistics files. Make sure these paths are correct:\n- {text_stats_path}\n- {image_stats_path}")
    raise


# Each statistics file is indexed by 'mean' and 'std'; these are later
# combined with random noise to build replacement embeddings.
text_mean_vector = text_stats['mean'].to(device)
text_std_vector = text_stats['std'].to(device)

image_mean_vector = image_stats['mean'].to(device)
image_std_vector = image_stats['std'].to(device)

print("Successfully loaded embedding statistics.")
print(f"Text stats loaded from: {text_stats_path}")
print(f"Image stats loaded from: {image_stats_path}")
print("="*50 + "\n")




def _single_token_id(letter):
    """Return the token id of ``letter``; raise ValueError unless it encodes to exactly one int id."""
    encoded = tokenizer.encode(letter, add_special_tokens=False)
    if len(encoded) != 1 or not isinstance(encoded[0], int):
        raise ValueError(f"Token '{letter}' is not represented by a single token ID. Got: {encoded}")
    return encoded[0]


# 'A' and 'B' are always needed; 'C' and 'D' only in the four-choice setting.
token_a_id = _single_token_id('A')
token_b_id = _single_token_id('B')

if num_choices == 4:
    token_c_id = _single_token_id('C')
    token_d_id = _single_token_id('D')
    all_choice_token_ids = [token_a_id, token_b_id, token_c_id, token_d_id]
else:
    all_choice_token_ids = [token_a_id, token_b_id]

# Report the resolved ids in choice order (A, B[, C, D]).
for letter, token_id in zip("ABCD", all_choice_token_ids):
    print(f"Token ID for '{letter}': {token_id}")


# Minimum total probability mass on the valid choice tokens; below it the
# forced-choice distribution falls back to uniform.
confidence_threshold = 0.3


def get_choice_distributions(logits, num_options, choice_token_ids=None, threshold=None):
    """Turn next-token logits into probability distributions over the answer letters.

    Args:
        logits: Tensor indexed as ``logits[0, -1, :]``; only the last position
            of the first batch element is used.
        num_options: Number of answer options that are valid for this sample.
            Values outside ``[0, len(choice_token_ids)]`` are clamped.
        choice_token_ids: Token ids of the answer letters, in choice order.
            Defaults to the module-level ``all_choice_token_ids``.
        threshold: Minimum total full-vocabulary probability on the valid
            choices. Defaults to the module-level ``confidence_threshold``.

    Returns:
        A tuple ``(forced_choice_probs, absolute_probs, [predicted_token_id])``:

        * ``forced_choice_probs``: CPU tensor with one entry per choice token.
          The first ``num_options`` entries are a softmax restricted to the
          valid choices, or a uniform distribution when the model puts less
          than ``threshold`` probability mass on them; remaining entries are 0.
        * ``absolute_probs``: full-vocabulary softmax probabilities of every
          choice token.
        * ``[predicted_token_id]``: arg-max token id over the whole vocabulary.
    """
    if choice_token_ids is None:
        choice_token_ids = all_choice_token_ids
    if threshold is None:
        threshold = confidence_threshold
    # A list (not a tuple) is required so the indexing below selects elements
    # rather than being interpreted as a multi-dimensional index.
    choice_token_ids = list(choice_token_ids)

    # Work in float32 so the softmax is not computed in reduced precision.
    last_token_logits = logits[0, -1, :].float()
    overall_predicted_token_id = torch.argmax(last_token_logits).item()

    full_vocab_probs = torch.softmax(last_token_logits, dim=-1)
    absolute_probs = full_vocab_probs[choice_token_ids]

    max_options = len(choice_token_ids)
    forced_choice_probs = torch.zeros(max_options)

    # Clamp so that a sample declaring more options than there are choice
    # tokens (or a negative count) still yields a distribution that sums to 1.
    num_valid = max(0, min(int(num_options), max_options))

    if num_valid > 0:
        choice_confidence_score = absolute_probs[:num_valid].sum()
        if choice_confidence_score < threshold:
            # The model is not really answering with a choice letter:
            # report "no preference" instead of renormalised noise.
            forced_choice_probs[:num_valid] = 1.0 / num_valid
        else:
            valid_logits = last_token_logits[choice_token_ids[:num_valid]]
            forced_choice_probs[:num_valid] = torch.softmax(valid_logits, dim=-1)

    return forced_choice_probs, absolute_probs, [overall_predicted_token_id]







# Banner: which model is running and where the results will be written.
print('------------------')
print(model_name)
print('output_json_path:', output_json_path)
print('------------------')

# Read the full dataset (a JSON list of samples) into memory.
with open(input_json_path, 'r') as input_file:
    data = json.load(input_file)

# Running totals for the accuracy report: samples processed / answered correctly.
count, right_answer = 0, 0





# Token ids used to locate the question text and the image placeholders inside
# the chat-templated prompt.
# NOTE(review): the meanings below are presumed from how the ids are used in
# this loop - confirm them against the Gemma-3 tokenizer vocabulary.
TEXT_START_MARKER_ID = 108       # presumably "\n\n"; the question starts right after its first occurrence
TEXT_END_MARKER_ID = 236761      # presumably "."; the question ends at its last occurrence
IMAGE_START_MARKER_ID = 255999   # presumably <start_of_image>; must occur exactly once
IMAGE_END_MARKER_ID = 256000     # presumably <end_of_image>; must occur exactly once


def _generate_first_token(model_inputs):
    """Greedily generate a single token for ``model_inputs`` and return the generate output with scores."""
    return model.generate(
        **model_inputs,
        max_new_tokens=1,
        return_dict_in_generate=True,
        output_scores=True,
        do_sample=False,
        top_p=None,
        top_k=None,
    )


def _embedding_inputs(inputs_embeds, template_inputs):
    """Build generate() kwargs from independent copies of the embeddings and the prompt's masks."""
    return {
        'inputs_embeds': inputs_embeds.clone(),
        'attention_mask': template_inputs['attention_mask'].clone(),
        'token_type_ids': template_inputs['token_type_ids'].clone(),
    }


def _as_float_list(tensor):
    """Convert a tensor to a (nested) list of Python floats for JSON serialisation."""
    return tensor.detach().to(dtype=torch.float32, device='cpu').tolist()


for item in tqdm(data):

    # ---- Load the image: from disk, or from the base64 payload in the JSON ----
    if image_base_dir:
        image_path = os.path.join(image_base_dir, item['image'])
        image_ori = Image.open(image_path).convert("RGB")
    else:
        img_bytes = base64.b64decode(item['image'])
        # Convert to RGB here as well so both branches hand the processor the same image mode.
        image_ori = Image.open(io.BytesIO(img_bytes)).convert("RGB")

    # ---- Locate the first question turn and the first answer turn ----
    question = None
    answer_turn = None
    for turn in item['conversations']:
        if turn['from'] == 'human' and question is None:
            question = turn['value'].replace('\n<image>', '')
        elif turn['from'] == 'gpt' and answer_turn is None:
            answer_turn = turn
    if question is None or answer_turn is None:
        # Fail before spending any compute, with a message that names the problem.
        raise ValueError(
            "Each item needs a 'human' and a 'gpt' conversation turn; "
            f"got roles {[turn['from'] for turn in item['conversations']]}"
        )

    message_dual = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a helpful assistant."}]
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_ori},
                {"type": "text", "text": question}
            ]
        }
    ]

    inputs = processor.apply_chat_template(
        message_dual,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device, dtype=torch.bfloat16)

    # Everything below is inference only: keep it all under no_grad so no
    # autograd graph is built (and then cloned) for the embedding tensors.
    with torch.no_grad():

        # Token ids of the bare question; used only to sanity-check the text span below.
        text_inputs_ids = tokenizer([question], return_tensors="pt").input_ids.to(model.device)

        # ---- Build the full input embeddings with image features spliced in ----
        inputs_embeds_full = model.model.get_input_embeddings()(inputs['input_ids'])

        image_embeds = model.model.get_image_features(inputs['pixel_values'])
        image_embeds = image_embeds.to(inputs_embeds_full.device, inputs_embeds_full.dtype)

        special_image_mask = model.model.get_placeholder_mask(
            inputs['input_ids'], inputs_embeds_full, image_features=image_embeds
        )
        inputs_embeds_full = inputs_embeds_full.masked_scatter(special_image_mask, image_embeds)

        assert inputs_embeds_full.shape[1] == inputs['input_ids'].shape[1]

        # ---- Locate the text span and the image span in the prompt ----
        prompt_ids = inputs['input_ids'][0]
        text_ind_1 = (prompt_ids == TEXT_START_MARKER_ID).nonzero(as_tuple=True)[0][0].item() + 1
        text_ind_2 = (prompt_ids == TEXT_END_MARKER_ID).nonzero(as_tuple=True)[0][-1].item()

        image_ind_1 = (prompt_ids == IMAGE_START_MARKER_ID).nonzero(as_tuple=True)[0].item()
        image_ind_2 = (prompt_ids == IMAGE_END_MARKER_ID).nonzero(as_tuple=True)[0].item()

        text_span = slice(text_ind_1, text_ind_2 + 1)      # inclusive of the end marker
        image_span = slice(image_ind_1 + 1, image_ind_2)   # strictly between the two markers

        text_feature = inputs_embeds_full[0, text_span, :]
        image_feature = inputs_embeds_full[0, image_span, :]

        # The text span must match the standalone tokenization of the question
        # minus one token, and the image span must hold every image embedding.
        # NOTE(review): the "- 1" presumably accounts for a BOS token - confirm.
        assert text_feature.shape[0] == text_inputs_ids.shape[1] - 1
        assert image_feature.shape[0] == image_embeds.shape[-2]

        # Mean-pool over tokens -> one vector of size hidden_dim per modality.
        text_feature = text_feature.mean(dim=0)
        image_feature = image_feature.mean(dim=0)

        # ---- Image-masked inputs: image tokens replaced by noise drawn from the image statistics ----
        # (The image noise is drawn before the text noise; keep this order so
        # the seeded random stream stays the same.)
        inputs_image_masked_dict = _embedding_inputs(inputs_embeds_full, inputs)
        visual_replacement_embeds = (
            torch.randn_like(inputs_embeds_full[:, image_span, :]) * (image_std_vector * scale_factor)
            + image_mean_vector
        )
        inputs_image_masked_dict['inputs_embeds'][:, image_span, :] = visual_replacement_embeds

        # ---- Text-masked inputs: question tokens replaced by noise drawn from the text statistics ----
        inputs_text_masked_dict = _embedding_inputs(inputs_embeds_full, inputs)
        lang_replacement_embeds = (
            torch.randn_like(inputs_embeds_full[:, text_span, :]) * (text_std_vector * scale_factor)
            + text_mean_vector
        )
        inputs_text_masked_dict['inputs_embeds'][:, text_span, :] = lang_replacement_embeds

        # ---- 1) Full multimodal prediction (image + text) ----
        cont = _generate_first_token(inputs)
        multi_logits = torch.stack(cont.scores, dim=1)
        multi_probs, multi_orig_probs, output_tokens = get_choice_distributions(multi_logits, item["num_options"])

        # The arg-max of the returned scores must be the token generate() actually emitted.
        assert int(cont.sequences[0][-1]) == int(output_tokens[0])

        output_text = processor.batch_decode(
            output_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        # ---- 2) Image-only prediction: the text has been masked out ----
        image_single_output = _generate_first_token(inputs_text_masked_dict)
        logits_image_single = torch.stack(image_single_output.scores, dim=1)
        visual_probs, visual_orig_probs, visual_tokens = get_choice_distributions(
            logits_image_single, item["num_options"]
        )

        # ---- 3) Text-only prediction: the image has been masked out ----
        text_single_output = _generate_first_token(inputs_image_masked_dict)
        lang_logits = torch.stack(text_single_output.scores, dim=1)
        language_two_probs, language_two_orig_probs, lang_tokens = get_choice_distributions(
            lang_logits, item["num_options"]
        )

    response = output_text

    # ---- Record the results on the answer turn (the ground truth moves to 'label') ----
    answer_turn['label'] = answer_turn['value']
    answer_turn['value'] = response
    answer_turn['v_feature'] = _as_float_list(image_feature)
    answer_turn['l_feature'] = _as_float_list(text_feature)
    answer_turn['v_prob'] = _as_float_list(visual_probs)
    answer_turn['l_prob'] = _as_float_list(language_two_probs)
    answer_turn['vl_prob'] = _as_float_list(multi_probs)
    answer_turn['v_orig_prob'] = _as_float_list(visual_orig_probs)
    answer_turn['l_orig_prob'] = _as_float_list(language_two_orig_probs)
    answer_turn['vl_orig_prob'] = _as_float_list(multi_orig_probs)

    # Periodic progress dump (every 100th sample).
    if count % 100 == 0:
        print('------------------')
        for key in ['label', 'value', 'v_prob', 'l_prob', 'vl_prob',
                    'v_orig_prob', 'l_orig_prob', 'vl_orig_prob']:
            print(f"{key}: {answer_turn[key]}")
        for key in ['v_feature', 'l_feature']:
            print(f"{key}: {len(answer_turn[key])}")
        print('------------------')

    # ---- Accuracy bookkeeping: case- and whitespace-insensitive exact match ----
    normalized_response = response.lower().strip()
    normalized_label = answer_turn['label'].lower().strip()

    if normalized_response == normalized_label:
        right_answer += 1

    count += 1


# Final accuracy over all processed samples (0 when nothing was processed).
if count > 0:
    accuracy = right_answer / count
else:
    accuracy = 0
print(f"Processed {count} samples. Current accuracy: {accuracy:.4f}")


# Persist the dataset, now annotated in place with predictions, probabilities and features.
with open(output_json_path, 'w') as output_file:
    json.dump(data, output_file, indent=2)