import json
import random
import os
from PIL import Image

def get_layers(model):
    """
    Return the stack of transformer layers of a LLaVA-like model.

    Two layouts are probed, in this order:
      1. ``model.model.language_model.layers``
      2. ``model.model.layers``

    Raises:
        AttributeError: if neither location exposes a ``layers`` attribute.
    """
    backbone = model.model
    nested_lm = getattr(backbone, "language_model", None)

    # The nested language model takes precedence over the backbone itself.
    for candidate in (nested_lm, backbone):
        if candidate is not None and hasattr(candidate, "layers"):
            return candidate.layers

    raise AttributeError("Cannot find layers in this model structure.")

def get_data_llava(
    image_dir,
    ann_file,          # COCO annotation file, e.g., instances_val2014.json
    image_names,
    n_patching=128,
    n_val=128,
    prompt_template="Please describe the image in detail.",
    patch_size=14,
    target_size=336,
    seed=None,
):
    """
    Prepare data samples for LLaVA evaluation/training.

    The image names are shuffled and split into two disjoint groups: the
    first ``n_patching`` shuffled entries become the patching samples and the
    next ``n_val`` entries become the validation samples. Images that cannot
    be read are reported on stdout and skipped, so either list may end up
    shorter than requested.

    Args:
        image_dir: Directory that contains the image files.
        ann_file: Path to a COCO-style annotation JSON file with
            ``categories``, ``images`` and ``annotations`` entries.
        image_names: Sequence of image file names (relative to ``image_dir``).
        n_patching: Maximum number of patching samples.
        n_val: Maximum number of validation samples.
        prompt_template: Instruction text placed after the image tokens.
        patch_size: Patch edge length used to derive the image-token count.
        target_size: Image edge length used to derive the image-token count.
        seed: Optional seed for a private random generator, making the split
            reproducible without touching the global ``random`` state. When
            ``None`` (default) the global ``random`` module is used.

    Returns:
        A dict with the keys ``"patching_samples"`` and ``"val_samples"``.
        Each value is a list of dicts with the keys ``"image"`` (RGB PIL
        image), ``"prompt"``, ``"gt_label"`` (sorted list of category names,
        empty if the file has no annotations) and ``"file_name"``.

    Raises:
        KeyError: if an annotation references an image id or category id that
            is not declared in the annotation file.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(ann_file, "r", encoding="utf-8") as f:
        coco = json.load(f)
    catid2name = {cat["id"]: cat["name"] for cat in coco["categories"]}
    imgid2fname = {img["id"]: img["file_name"] for img in coco["images"]}

    # Collect the distinct category names annotated on every image file.
    label_sets = {}
    for ann in coco["annotations"]:
        fname = imgid2fname[ann["image_id"]]
        label_sets.setdefault(fname, set()).add(catid2name[ann["category_id"]])
    # Sort so the label order does not depend on string hash randomization.
    fname2labels = {fname: sorted(names) for fname, names in label_sets.items()}

    idx = list(range(len(image_names)))
    if seed is None:
        random.shuffle(idx)
    else:
        random.Random(seed).shuffle(idx)
    patching_idx = idx[:n_patching]
    val_idx = idx[n_patching:n_patching + n_val]

    # The prompt is identical for every sample, so build it once. It holds
    # one "<image>" token per patch, separated by single spaces.
    # NOTE(review): some processors expand a single "<image>" placeholder
    # themselves -- confirm the consumer expects pre-expanded tokens.
    n_patches = (target_size // patch_size) ** 2
    image_tokens_str = " ".join(["<image>"] * n_patches)
    prompt = f"<|im_start|>user\n{image_tokens_str}\n{prompt_template}\n<|im_end|>\n<|im_start|>assistant\n"

    def build_samples(indices):
        samples = []
        for i in indices:
            fname = image_names[i]
            img_path = os.path.join(image_dir, fname)
            try:
                # convert() returns an independent in-memory image, so the
                # underlying file can be closed as soon as it has been read.
                with Image.open(img_path) as raw:
                    img = raw.convert("RGB")
            except Exception as e:
                # Deliberate best-effort: report the unreadable image and go on.
                print(f"Cannot read image {img_path}: {e}")
                continue
            samples.append({
                "image": img,
                "prompt": prompt,
                # Copy so samples never share one mutable label list.
                "gt_label": list(fname2labels.get(fname, [])),
                "file_name": fname,
            })
        return samples

    patching_samples = build_samples(patching_idx)
    val_samples = build_samples(val_idx)

    data = {
        "patching_samples": patching_samples,
        "val_samples": val_samples,
    }
    print(f"Number of patching_samples: {len(patching_samples)}, number of val_samples: {len(val_samples)}")
    return data
