import argparse, json, os
import torch
import numpy as np
import re

from torch.utils.data import DataLoader, Subset
from src_files.utils.logger import setup_logger
from src_files.utils.helper import get_raw_dict

from models import MODEL_CLASSES
from data.mlc import MLCDataset


# Maps each supported `--data_type` value to the dataset class that main()
# instantiates for it.
DATA_CLASS = {
    'mlc': MLCDataset
}

# Maps a `--model_path` value to the short tag that main() embeds in the
# name of the output directory.
MODEL_PATH = {
    "Qwen/Qwen2.5-VL-7B-Instruct": "qwen2_5vl_7b",
    "Qwen/Qwen2.5-VL-32B-Instruct": "qwen2_5vl_32b",
    "Qwen/Qwen2.5-VL-72B-Instruct": "qwen2_5vl_72b",
    "Qwen/Qwen2-VL-7B-Instruct": "qwen2vl_7b",
    "OpenGVLab/InternVL2_5-8B": "internvl2_5_8b",
    "OpenGVLab/InternVL3-8B": "internvl3_8b",
    "OpenGVLab/InternVL3-14B": "internvl3_14b",
    "OpenGVLab/InternVL3_5-14B": "internvl3_5_14b",
    "OpenGVLab/InternVL3-38B": "internvl3_38b"
}

def get_args():
    """Parse the command-line options of the per-class LVLM query script.

    Returns:
        argparse.Namespace: The parsed options. ``--model_type``,
        ``--model_path``, ``--start_id`` and ``--end_id`` are required; every
        other option has a default.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', choices=MODEL_CLASSES.keys(), required=True)
    parser.add_argument('--model_path', required=True)
    parser.add_argument('--data_name', default='coco2014')
    # Only keys of DATA_CLASS can be instantiated later, so reject anything
    # else at parse time instead of failing after the model has been loaded.
    parser.add_argument('--data_type', default='mlc', choices=list(DATA_CLASS))
    parser.add_argument('--data_path', default='')
    parser.add_argument('--output', default='./output')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--start_id', type=int, required=True)
    parser.add_argument('--end_id', type=int, required=True)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--print_freq', type=int, default=1)
    parser.add_argument('--mask_path', default=None, help='Path to a 0/1 mask matrix (npy/txt), shape [num_classes, num_samples] or transposed')

    # The help text lists exactly the accepted choices and what each prompt
    # style uses (see generate_query_text).
    parser.add_argument('--query_type', default='scat', choices=['scat', 'scat_attr', 'mcat', 'mcat_sib'],
                        help="Prompt style: 'scat' (simple yes/no), 'scat_attr' (parent + siblings + description), "
                             "'mcat' (parent only), 'mcat_sib' (parent + siblings). "
                             "Styles other than 'scat' fall back to the simple prompt when the class has no metadata.")
    parser.add_argument('--utils_dir', default='',
                        help="Directory that may contain '<data_name>_cat_attr_dict.json'.")
    parser.add_argument('--cat_attr_json', default=None,
                        help="Optional explicit path to '<data_name>_cat_attr_dict.json'. If not set, will try --utils_dir.")
    return parser.parse_args()


def main():
    """Query an LVLM with a yes/no prompt for every class in a class-id range.

    For each class id in ``[--start_id, --end_id]`` (inclusive), the samples
    selected by that class's mask row are sent through the model in batches
    together with the class prompt. The answers are saved to
    ``<output>/answer_<cls_id>.npy`` with one entry per dataset sample;
    samples the mask did not select are stored as ``'no'``.

    Raises:
        ValueError: If the class-id range falls outside the class list, or if
            the model returns a different number of answers than samples
            queried for a class.
        ImportError: If the requested model class is not available.
    """
    import sys  # only needed here, to log the command line

    args = get_args()
    if args.seed is not None:
        torch.manual_seed(args.seed)

    # Load the class list first so a bad id range fails before the (expensive)
    # model load instead of part-way through the run.
    cls_dir = os.path.join(args.data_path, args.data_name)
    if args.data_name == 'objects365':
        cls_dir = os.path.join(cls_dir, 'o251')
    cls_names = load_cls_names(os.path.join(cls_dir, f'{args.data_name}_cls_names.txt'))
    num_classes = len(cls_names)
    if args.start_id < 0 or args.end_id >= num_classes:
        # A negative id would silently alias a class from the end of the list.
        raise ValueError(
            f"class id range [{args.start_id}, {args.end_id}] is outside the valid range [0, {num_classes - 1}]"
        )

    # Model paths missing from MODEL_PATH (e.g. local checkpoints) get a tag
    # derived from the path instead of aborting with a bare KeyError.
    model_tag = MODEL_PATH.get(args.model_path)
    if model_tag is None:
        base = os.path.basename(os.path.normpath(args.model_path))
        model_tag = re.sub(r"[^0-9A-Za-z_.-]+", "_", base).strip(".") or "unknown_model"

    args.output = os.path.join(args.output, args.data_name, f'align_new_new_{args.query_type}_{model_tag}_{args.data_type}_{args.seed}')
    args.log_output = os.path.join(args.output, f"{args.start_id}_{args.end_id}")
    os.makedirs(args.log_output, exist_ok=True)

    logger = setup_logger(args.log_output, color=False, name="LVLM")
    logger.info("Command: " + ' '.join(sys.argv))
    # Context manager so the config file is flushed and closed deterministically.
    with open(os.path.join(args.output, "config.json"), 'w') as cfg_file:
        json.dump(get_raw_dict(args), cfg_file, indent=2)

    ModelClass = MODEL_CLASSES[args.model_type]
    if ModelClass is None:
        raise ImportError(f"Model '{args.model_type}' is not available in this environment.")
    model_runner = ModelClass(model_type=args.model_type, model_path=args.model_path)
    model_runner.load_model_and_processor()

    DataClass = DATA_CLASS[args.data_type]
    dataset = DataClass(args.data_name, args.data_path)

    try:
        num_samples = len(dataset)
    except TypeError:
        # len() ignores a __len__ that is not defined on the type itself.
        num_samples = dataset.__len__()
    mask = load_mask(args.mask_path, num_classes, num_samples)

    cat_attr = try_load_cat_attr(args)

    # Loop over classes
    for cls_id in range(args.start_id, args.end_id + 1):
        cls_name = cls_names[cls_id]
        cls_indices = np.where(mask[cls_id])[0].tolist()

        # Samples that are not queried keep the default answer 'no'.
        # dtype '<U3' stores at most three characters per answer.
        full = np.full((num_samples,), 'no', dtype='<U3')

        if cls_indices:
            loader = DataLoader(Subset(dataset, cls_indices), batch_size=args.batch_size, shuffle=False)
            query_text = generate_query_text(args, cls_name, cat_attr)

            cls_answers = []
            for i, imgs in enumerate(loader):
                batch = {"image": imgs, "text": [query_text] * len(imgs)}
                batch_qa = model_runner.run_batch_inference(batch)
                cls_answers.extend(batch_qa['answer'])
                # Guard the modulo so --print_freq 0 disables logging
                # instead of raising ZeroDivisionError.
                if args.print_freq > 0 and i % args.print_freq == 0:
                    logger.info(f"[{i}/{len(loader)}] {cls_name} | prompt={query_text} | {batch_qa}")

            # Check explicitly: a single answer would otherwise be broadcast
            # silently over every selected sample.
            if len(cls_answers) != len(cls_indices):
                raise ValueError(
                    f"class {cls_id} ({cls_name}): got {len(cls_answers)} answers for {len(cls_indices)} samples"
                )
            full[np.array(cls_indices, dtype=np.int64)] = np.array(cls_answers, dtype='<U3')

        np.save(os.path.join(args.output, f"answer_{cls_id}.npy"), full)

# -----------------------
# Prompt generation (ported & robust)
# -----------------------
def generate_query_text(args, cls_name: str, cat_attr: dict) -> str:
    """
    Build the yes/no prompt for one class according to ``args.query_type``:
      - scat:       plain question naming the class only
      - scat_attr:  parent + siblings context, asks about the description
      - mcat:       parent context only
      - mcat_sib:   parent + siblings context
    Whenever the class has no metadata entry, or the selected style ends up
    with no context sentence (mcat / mcat_sib), the plain question is returned.
    """
    mode = (args.query_type or 'scat').lower()
    plain_prompt = f"Please only answer yes or no. Is there a {cls_name} in this image?"

    if mode == 'scat':
        return plain_prompt

    meta = (cat_attr or {}).get(_norm(cls_name))
    if not meta:
        # Nothing known about this class: use the plain question.
        return plain_prompt

    raw_parent = meta.get('parent', None)
    raw_siblings = meta.get('siblings', None)
    raw_description = meta.get('description', None)

    parent = _norm(raw_parent) if raw_parent else ''
    # A missing description falls back to the normalized class name.
    description = _norm(raw_description if raw_description else cls_name)
    sib_txt = _clean_siblings(raw_siblings, cls_name, limit=8) if raw_siblings else ''

    def context_for(subject, use_siblings):
        # Context sentences that precede the actual question.
        lines = []
        if parent:
            lines.append(f"{subject} is a type of {parent}.")
        if use_siblings and sib_txt:
            lines.append(f"It is not a {sib_txt}.")
        return lines

    closing = "Answer with only 'yes' or 'no'."

    if mode == 'scat_attr':
        subject = description.capitalize() if description else cls_name.capitalize()
        question = f"Carefully examine the image and decide if it truly contains a {description}."
        return " ".join(context_for(subject, True) + [question, closing])

    if mode in ('mcat', 'mcat_sib'):
        context = context_for(cls_name.capitalize(), mode == 'mcat_sib')
        if not context:
            return plain_prompt
        question = f"Carefully examine the image and decide if it contains a {cls_name}."
        return " ".join(context + [question, closing])

    return plain_prompt



def try_load_cat_attr(args) -> "dict | None":
    """
    Load the category-attribute dictionary for ``args.data_name``.

    Candidate files are tried in order: ``args.cat_attr_json`` (if set), then
    ``<args.utils_dir>/<data_name>_cat_attr_dict.json`` (if ``utils_dir`` is
    set). The first existing file that parses to a JSON object wins; its keys
    are normalized with ``_norm`` for robust lookup.

    Loading stays best-effort: an unreadable or malformed candidate is skipped,
    but a warning is emitted so the problem does not go unnoticed.

    Returns:
        The normalized dictionary, or ``None`` when no candidate could be loaded.
    """
    import warnings

    candidates = []
    if args.cat_attr_json:
        candidates.append(args.cat_attr_json)
    if args.utils_dir:
        candidates.append(os.path.join(args.utils_dir, f"{args.data_name}_cat_attr_dict.json"))

    for path in candidates:
        if not os.path.exists(path):
            continue
        try:
            with open(path, 'r', encoding='utf-8') as f:
                raw = json.load(f)
        except (OSError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            warnings.warn(f"Skipping category attribute file '{path}': {err}")
            continue
        if not isinstance(raw, dict):
            warnings.warn(f"Skipping category attribute file '{path}': top-level JSON value is not an object")
            continue
        return {_norm(key): value for key, value in raw.items()}
    return None


def _norm(s: str) -> str:
    return re.sub(r"\s+", " ", str(s).strip().lower())

def _expand_slash_terms(term: str):
    """Split things like 'blackboard/whiteboard' into ['blackboard','whiteboard']."""
    parts = [p.strip() for p in re.split(r"/+", str(term))]
    return [p for p in parts if p]

def _clean_siblings(raw_siblings, cls_name: str, limit: int = 8):
    """
    Turn raw sibling names into a clean, comma-joined string.

    Each entry is split on slashes and normalized; entries equal to the class
    itself and duplicates are dropped (first occurrence wins), and the result
    is capped at ``limit`` names when ``limit`` is truthy.

    Args:
        raw_siblings: Iterable of sibling names, a single name given as a
            plain string, or None/empty.
        cls_name: The class whose siblings are listed; it is removed from the
            result.
        limit: Maximum number of names to keep; a falsy value disables the cap.

    Returns:
        A string like 'sibling1, sibling2, sibling3' ('' when nothing remains).
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(raw_siblings, str):
        raw_siblings = [raw_siblings]

    cls_key = _norm(cls_name)
    seen = set()
    uniq = []
    for sibling in (raw_siblings or []):
        for term in _expand_slash_terms(sibling):
            key = _norm(term)  # normalize once per term
            if key and key != cls_key and key not in seen:
                seen.add(key)
                uniq.append(key)

    if limit and len(uniq) > limit:
        uniq = uniq[:limit]
    return ", ".join(uniq)

def load_cls_names(file_path):
    """Read class names from a UTF-8 text file, one per line.

    Every line is stripped and lower-cased; the list keeps the file's line
    order, including entries for blank lines.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [raw_line.strip().lower() for raw_line in handle]

def load_mask(mask_path, num_classes, num_samples):
    """Load a 0/1 mask for per-class annotation selection.

    Args:
        mask_path: Path to a ``.npy`` file or a text file (comma- or
            whitespace-delimited). A falsy value selects everything.
        num_classes: Expected number of mask rows.
        num_samples: Expected number of mask columns.

    Returns:
        A boolean array of shape (num_classes, num_samples). Non-boolean input
        is thresholded at 0.5. A transposed mask is auto-fixed, and so is a
        0-D/1-D mask when one of the two dimensions is 1 (np.loadtxt squeezes
        single-row and single-column files).

    Raises:
        ValueError: If the loaded array cannot be matched to the expected shape.
    """
    expected = (num_classes, num_samples)
    if not mask_path:
        return np.ones(expected, dtype=bool)

    p = str(mask_path)
    if os.path.splitext(p)[1].lower() == ".npy":
        arr = np.load(p)
    else:
        try:
            arr = np.loadtxt(p, delimiter=',')
        except ValueError:
            # Not comma-separated: retry with the default whitespace delimiter.
            arr = np.loadtxt(p)

    arr = np.asarray(arr)
    if arr.dtype != np.bool_:
        arr = (arr.astype(np.float32) > 0.5)

    # Restore the axis that was squeezed away for single-class / single-sample masks.
    if arr.ndim < 2 and 1 in expected and arr.size == num_classes * num_samples:
        arr = arr.reshape(expected)

    if arr.shape == expected:
        return arr
    if arr.shape == (num_samples, num_classes):
        return arr.T
    raise ValueError(f"mask shape {arr.shape} incompatible; expected ({num_classes}, {num_samples}) or transpose")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
