import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
import sys
import argparse
import transformers
from transformers import AutoProcessor
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration

from train import MyDataArguments, DataCollatorForSupervisedDataset
from src.uav_vln.dataset import LazySupervisedDataset
from src.uav_vln.constants import IGNORE_INDEX

# Expose all eight GPUs to the process unless the caller already chose a subset
# (e.g. `CUDA_VISIBLE_DEVICES=0,1 python this_script.py`). This only takes
# effect if CUDA has not been initialized yet; torch initializes CUDA lazily,
# so setting it after `import torch` is normally still in time.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7")

def _attention_save_path(save_dir, image_path):
    """Return the ``.pt`` path under *save_dir* named after *image_path*'s file stem.

    Example: ``("out", "/imgs/scene_01.png")`` -> ``os.path.join("out", "scene_01.pt")``.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    return os.path.join(save_dir, f"{stem}.pt")


@torch.no_grad()
def main(
    teacher_model_name="model/Qwen2.5-VL-32B-Instruct",
    data_path="dataset/sftdatabbox.json",
    save_dir="dataset/teacher_attn_map/",
):
    """Dump the teacher's first-layer, head-averaged attention map for every sample.

    Loads ``Qwen2_5_VLForConditionalGeneration`` in fp16 with eager attention
    (required for ``output_attentions=True`` to return weights). It runs one
    forward pass per sample of ``LazySupervisedDataset`` and writes one
    ``<image file stem>.pt`` tensor per sample into ``save_dir``.

    Args:
        teacher_model_name: Path or hub id of the teacher checkpoint; also used
            to load the matching processor.
        data_path: Annotation file passed to ``LazySupervisedDataset``.
        save_dir: Output directory; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)

    # -------------------
    # Load teacher model
    # -------------------
    # Autograd is already disabled for the whole function by @torch.no_grad().
    print("Loading teacher model...")
    teacher_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        teacher_model_name,
        torch_dtype=torch.float16,
        # sdpa / flash-attention kernels do not return attention weights.
        attn_implementation="eager",
        device_map="auto",
    )
    teacher_model.eval()
    teacher_model.config.use_cache = False

    # -------------------
    # Load tokenizer & processor
    # -------------------
    processor = AutoProcessor.from_pretrained(teacher_model_name)

    # -------------------
    # Dataset
    # -------------------
    data_args = MyDataArguments()
    # Presumably disables loading precomputed attention maps, since this script
    # is what generates them -- confirm against LazySupervisedDataset.
    data_args.teacher_map_path = None
    data_args.neg_map_path = None
    data_args.resize_ratio = 4

    dataset = LazySupervisedDataset(
        tokenizer=processor.tokenizer,
        processor=processor,
        data_path=data_path,
        data_args=data_args,
    )

    data_collator = DataCollatorForSupervisedDataset(tokenizer=processor.tokenizer)
    # batch_size must stay 1: only batch element 0 is saved below, and the
    # output file name is taken from batch["image_path"][0].
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=data_collator,
        num_workers=4,
    )

    # With device_map="auto" the model is sharded across GPUs; inputs are sent
    # to the device reported by teacher_model.device (constant, so hoisted).
    input_device = teacher_model.device

    for batch in tqdm(dataloader):
        batch = {
            key: value.to(input_device) if torch.is_tensor(value) else value
            for key, value in batch.items()
        }

        outputs = teacher_model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            pixel_values=batch["pixel_values"],
            image_grid_thw=batch["image_grid_thw"],
            output_attentions=True,
            return_dict=True,
        )

        # outputs.attentions holds one tensor per decoder layer; [0] is the first
        # layer, documented by transformers as (batch, heads, seq, seq).
        # NOTE(review): output_attentions=True materializes every layer's map even
        # though only the first is kept; a forward hook on layer 0 would cut peak
        # GPU memory substantially for long image-token sequences.
        first_layer_attn = outputs.attentions[0]
        # Batch element 0, averaged over heads -> (seq, seq). Moved to CPU so the
        # saved file does not depend on this process's GPU layout when loaded.
        final_attn = first_layer_attn[0].mean(dim=0).cpu()
        # Release the all-layer attentions and logits now; otherwise they stay
        # referenced (and resident on the GPUs) during the next forward pass.
        del outputs, first_layer_attn

        # -------------------
        # Save file
        # -------------------
        # NOTE(review): files are keyed by image file stem only, so two samples
        # whose images share a base name overwrite each other.
        torch.save(final_attn, _attention_save_path(save_dir, batch["image_path"][0]))

    print(f"Attention scores have been saved to {save_dir}")


if __name__ == "__main__":
    # Script entry point: run the extraction with main()'s default configuration.
    main()
