import argparse
import torch
import os
import json
from tqdm import tqdm
import sys
import random
from PIL import Image, ImageFilter
import numpy as np

from transformers import set_seed
from transformers import AutoProcessor, LlavaForConditionalGeneration


def process_question_data(file_path):
    """Load a JSON question file and normalise each item into a common record.

    Two item layouts are recognised:

    * format1 (multiple choice): has ``data_id``, ``question``, ``answer`` and
      ``choice_a`` .. ``choice_d``. The image name is ``str(data_id)`` and the
      four choices are appended to the question text as ``A.`` .. ``D.`` lines.
    * format2 (yes/no): has ``image``, ``question`` and a string ``answer``
      equal to "yes" or "no" in any letter case. The label is normalised to
      ``"Yes"`` / ``"No"``.

    Items matching neither layout (including items missing ``question`` or
    ``answer``, or with an empty image name) are reported on stdout and
    skipped instead of aborting the whole load.

    Args:
        file_path: Path to a UTF-8 JSON file containing a list of items.

    Returns:
        A list of dicts with the keys ``image``, ``text``, ``label``,
        ``choices``, ``type`` and ``folder_type``.
    """
    with open(file_path, 'r', encoding='utf-8') as infile:
        data = json.load(infile)

    choice_keys = ('choice_a', 'choice_b', 'choice_c', 'choice_d')
    results = []

    for item in data:
        format_type = None
        image_file = None
        folder_type = None

        # Both layouts read item['question'] and item['answer'] below, so
        # require them up front; otherwise a malformed item would raise
        # KeyError instead of being skipped.
        if isinstance(item, dict) and 'question' in item and 'answer' in item:
            answer = item['answer']
            if 'data_id' in item and all(key in item for key in choice_keys):
                format_type = 'format1'
                image_file = f"{item['data_id']}"
                folder_type = 'format1_folder'
            elif (
                'image' in item
                # Guard .lower(): the answer may be a number/bool/None in JSON.
                and isinstance(answer, str)
                and answer.lower() in ('yes', 'no')
            ):
                format_type = 'format2'
                image_file = item['image']
                folder_type = 'format2_folder'

        if not format_type or not image_file:
            print(f"Skipping item with unknown format: {item}")
            continue

        question = item['question']

        if format_type == 'format1':
            choices = (
                f"A. {item['choice_a']}\nB. {item['choice_b']}\n"
                f"C. {item['choice_c']}\nD. {item['choice_d']}"
            )
            data_point = {
                "image": image_file,
                "text": f"{question}\n{choices}",
                "label": item['answer'],
                "choices": {
                    'A': item['choice_a'],
                    'B': item['choice_b'],
                    'C': item['choice_c'],
                    'D': item['choice_d'],
                },
                "type": "multiple_choice",
                "folder_type": folder_type,
            }
        else:
            # format2: detection above guarantees the answer is yes/no.
            correct_answer = "Yes" if item['answer'].lower() == "yes" else "No"
            data_point = {
                "image": image_file,
                "text": question,
                "label": correct_answer,
                "choices": {"Yes": "Yes", "No": "No"},
                "type": "yes_no",
                "folder_type": folder_type,
            }

        results.append(data_point)

    return results

def get_wrong_option(correct_answer, choices, question_type):
    """Return an answer label that differs from ``correct_answer``.

    For ``"yes_no"`` questions the opposite label is returned (anything other
    than ``"Yes"`` maps to ``"Yes"``). For ``"multiple_choice"`` questions a
    random letter among A-D other than ``correct_answer`` is drawn with the
    module-level ``random`` generator. Any other ``question_type`` yields
    ``None``.

    ``choices`` is accepted for interface compatibility but is not consulted.
    """
    if question_type == "yes_no":
        if correct_answer == "Yes":
            return "No"
        return "Yes"

    if question_type == "multiple_choice":
        # Keep A-D order so a seeded random.choice picks the same letter.
        distractors = [letter for letter in "ABCD" if letter != correct_answer]
        return random.choice(distractors)

    return None

def eval_model(args):
    """Collect activations for every question answered with a wrong option.

    Each question is turned into a prompt whose assistant turn is a wrong
    answer from ``get_wrong_option``. The model is run once per prompt with
    forward hooks on the modules named by ``args.layer_name_template``
    (formatted with ``layer_idx`` in ``range(args.num_layers)``). Hook outputs
    ("head-wise") and the model's ``hidden_states`` ("layer-wise") are
    collected, reduced to index ``-1`` of axis 1 when 3-D, and saved.

    Reads from ``args``: ``model_path``, ``question_file``,
    ``format1_image_folder``, ``format2_image_folder``,
    ``layer_name_template``, ``num_layers``, ``output``, ``layer_output``
    and ``save_layer_activations``.

    Side effects:
        * ``np.save(args.output, ...)`` with the head-wise activations.
        * ``np.save(args.layer_output, ...)`` with the layer-wise
          activations when ``args.save_layer_activations`` is true.
        * A ``<output stem>_metadata.json`` file describing the run.

    Questions whose image is missing, whose hooks did not all fire, or whose
    activations have an unexpected rank are reported on stdout and skipped.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LlavaForConditionalGeneration.from_pretrained(args.model_path)
    model.to(device)
    processor = AutoProcessor.from_pretrained(args.model_path)

    if hasattr(model, 'config'):
        model.config.output_hidden_states = True

    questions = process_question_data(args.question_file)

    # The hooked modules do not depend on the question, so resolve them once
    # instead of rebuilding the name->module mapping per layer per question.
    layer_names = [
        args.layer_name_template.format(layer_idx=i)
        for i in range(args.num_layers)
    ]
    named_modules = dict(model.named_modules())
    layers = []
    for name in layer_names:
        module = named_modules.get(name)
        # Compare with None: container modules define __len__, so an empty
        # one would be falsy and wrongly reported as missing.
        if module is not None:
            layers.append(module)
        else:
            print(f"Module not found: {name}")

    folder_by_type = {
        'format1_folder': args.format1_image_folder,
        'format2_folder': args.format2_image_folder,
    }

    all_layer_wise_activations = []
    all_head_wise_activations = []
    processed_count = 0

    for line in tqdm(questions):
        image_file = line["image"]
        qs = line["text"]
        gt_answer = line["label"]
        question_type = line["type"]
        folder_type = line["folder_type"]

        # Drawn before any skip so the random stream advances once per
        # question regardless of whether the question is later skipped.
        wrong_ans = get_wrong_option(gt_answer, line["choices"], question_type)

        if folder_type not in folder_by_type:
            print(f"Unknown folder type: {folder_type}, skipping...")
            continue
        image_path = os.path.join(folder_by_type[folder_type], image_file)

        try:
            image = Image.open(image_path)
        except FileNotFoundError:
            print(f"Image not found: {image_path}, skipping...")
            continue

        prompt = f"USER: <image>\n{qs} ASSISTANT: {wrong_ans}"

        # Close the image file once the processor has consumed it.
        with image:
            inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

        outputs_dict = {}

        def hook_fn(module, _inputs, output):
            # Keep only the first output per module for this forward pass.
            # NOTE(review): assumes each hooked module returns a single
            # tensor (``.cpu()`` is called on the output directly).
            if module not in outputs_dict:
                outputs_dict[module] = output.cpu()

        hook_handles = [layer.register_forward_hook(hook_fn) for layer in layers]
        try:
            with torch.no_grad():
                output = model(
                    **inputs,
                    output_hidden_states=True
                )
        finally:
            # Always detach the hooks, even if the forward pass raises, so
            # they do not accumulate on the model.
            for handle in hook_handles:
                handle.remove()

        if len(outputs_dict) != len(layer_names):
            print("Missing outputs for some layers, skipping this question")
            continue

        # Equal counts imply every requested layer was found and fired, so
        # indexing by ``layers`` is safe and fixes the order to layer index.
        head_wise_activations = (
            torch.stack([outputs_dict[layer] for layer in layers], dim=0)
            .detach().cpu().squeeze().numpy()
        )
        layer_wise_activations = (
            torch.stack(output.hidden_states, dim=0)
            .squeeze().detach().cpu().numpy()
        )

        # Validate both arrays before appending either one, so the two
        # result lists always stay aligned entry-for-entry.
        if layer_wise_activations.ndim != 3:
            print(f"Unexpected layer_wise_activations shape: {layer_wise_activations.shape}")
            continue

        if head_wise_activations.ndim == 2:
            head_features = head_wise_activations
        elif head_wise_activations.ndim == 3:
            # Index -1 on axis 1 - presumably the last token position.
            head_features = head_wise_activations[:, -1, :]
        else:
            print(f"Unexpected head_wise_activations shape: {head_wise_activations.shape}")
            continue

        all_layer_wise_activations.append(layer_wise_activations[:, -1, :].copy())
        all_head_wise_activations.append(head_features.copy())
        processed_count += 1

    print(f"Successfully processed {processed_count} questions")

    np.save(args.output, all_head_wise_activations)
    if args.save_layer_activations:
        np.save(args.layer_output, all_layer_wise_activations)

    metadata = {
        "total_questions": len(questions),
        "processed_questions": processed_count,
        "question_types": {
            "multiple_choice": sum(1 for q in questions if q["type"] == "multiple_choice"),
            "yes_no": sum(1 for q in questions if q["type"] == "yes_no")
        },
        "model_path": args.model_path,
        "format1_image_folder": args.format1_image_folder,
        "format2_image_folder": args.format2_image_folder,
        "question_file": args.question_file,
        "layer_name_template": args.layer_name_template,
        "num_layers": args.num_layers
    }

    metadata_path = os.path.splitext(args.output)[0] + "_metadata.json"
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)


def build_arg_parser():
    """Build the command-line parser for the activation-extraction script.

    Layer activations are saved by default. ``--save_layer_activations`` is
    kept for backward compatibility; ``--skip_layer_activations`` turns the
    saving off, which the original ``store_true``/``default=True`` flag could
    never do.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="path")
    parser.add_argument("--format1-image-folder", type=str, default="path",
                        help="Image folder for format 1 (data_id based images)")
    parser.add_argument("--format2-image-folder", type=str, default="path",
                        help="Image folder for format 2 (direct image field)")
    parser.add_argument("--question_file", type=str, default="path")
    parser.add_argument("--output", type=str, default='path')
    parser.add_argument("--layer_output", type=str, default='no')
    parser.add_argument("--save_layer_activations", dest="save_layer_activations",
                        action="store_true", default=True,
                        help="Save layer-wise activations to --layer_output (default: enabled)")
    # Counterpart flag writing to the same destination so the default-on
    # behaviour can actually be disabled from the command line.
    parser.add_argument("--skip_layer_activations", dest="save_layer_activations",
                        action="store_false",
                        help="Do not save layer-wise activations")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--layer-name-template", type=str,
                        required=True,
                        help="Layer name template, use {layer_idx} as placeholder for layer index")
    parser.add_argument("--num-layers", type=int, required=True,
                        help="Number of layers to hook")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    set_seed(args.seed)
    random.seed(args.seed)
    eval_model(args)