
import argparse, os, json, re
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

try:
    from src_files.utils.logger import setup_logger
except Exception:
    # The project logger could not be imported: provide a print-based stand-in
    # with the same call shape so the rest of the script keeps working.
    def setup_logger(*args, **kwargs):
        """Return a minimal logger object; all arguments are accepted and ignored."""
        class _PrintLogger:
            """Stand-in logger exposing only ``info``."""

            def info(self, *messages, **_ignored):
                # Keyword arguments are dropped; positional ones go to stdout.
                print(*messages)

        return _PrintLogger()

# Import the model registry. When `models` does not provide MODEL_PATH,
# fall back to importing MODEL_CLASSES alone with an empty MODEL_PATH.
try:
    from models import MODEL_CLASSES, MODEL_PATH
except Exception:
    from models import MODEL_CLASSES
    MODEL_PATH = {}

# NOTE(review): this assignment unconditionally rebinds MODEL_PATH, so whatever
# the import above provided is discarded — confirm that this override is intended.
# Maps a --model_path value to the short tag embedded in the output directory
# name (see build_out_dir).
MODEL_PATH = {
    "Qwen/Qwen2.5-VL-7B-Instruct": "qwen2_5vl_7b",
    "Qwen/Qwen2-VL-7B-Instruct": "qwen2vl_7b",
    "OpenGVLab/InternVL2_5-8B": "internvl2_5_8b",
    "OpenGVLab/InternVL3-8B": "internvl3_8b"
}

from data.mlc import MLCDataset

def load_cls_names(file_path):
    """Read a UTF-8 text file and return one class name per line.

    Each line is stripped of surrounding whitespace and lower-cased. Every
    line produces an entry, so a blank line yields an empty string.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [raw_line.strip().lower() for raw_line in handle]

def shard_indices_contiguous(n, k, s):
    """Return the indices of shard ``k`` when ``range(n)`` is cut into ``s`` consecutive slices.

    Slice boundaries are ``(n * k) // s`` (inclusive) and ``(n * (k + 1)) // s``
    (exclusive), so a shard may be empty when ``n < s``.
    """
    start, stop = (n * k) // s, (n * (k + 1)) // s
    return [*range(start, stop)]

def shard_indices_roundrobin(n, k, s):
    """Return every ``s``-th index of ``range(n)`` starting at ``k`` (shard ``k`` of ``s``)."""
    strided = range(k, n, s)
    return [*strided]

def build_out_dir(args):
    """Build (and create if needed) the output directory for a run.

    The directory is
    ``<output>/<data_name>/run_shards_group_<class_group>_<model_tag>_<data_type>_<seed>``.
    ``model_tag`` comes from ``MODEL_PATH`` when ``args.model_path`` is a known
    key. Otherwise it is derived from the last path component of
    ``args.model_path`` with unsafe characters replaced by ``_``, so a local
    checkpoint directory no longer fails with a bare ``KeyError``.

    Args:
        args: Namespace providing ``output``, ``data_name``, ``class_group``,
            ``model_path``, ``data_type`` and ``seed``.

    Returns:
        The path of the (now existing) output directory.
    """
    model_tag = MODEL_PATH.get(args.model_path)
    if model_tag is None:
        # Unknown model path: fall back to a filesystem-safe tag from its basename.
        base = os.path.basename(os.path.normpath(str(args.model_path)))
        model_tag = re.sub(r"[^A-Za-z0-9_.-]+", "_", base).strip("_") or "model"

    out_dir = os.path.join(
        args.output,
        args.data_name,
        f"run_shards_group_{args.class_group}_{model_tag}_{args.data_type}_{args.seed}",
    )
    os.makedirs(out_dir, exist_ok=True)
    return out_dir

def _split_text_to_tokens(s):
    text = re.sub(r"[\n\t]+", " ", str(s).lower())
    toks = re.split(r"[;,/|]|\band\b|\bor\b", text)
    toks = [re.sub(r"[^a-z0-9]+", " ", t).strip() for t in toks]
    toks = [t for t in toks if t]
    return toks

def _merge_answer_strings(str_list):
    """Tokenise every string in ``str_list`` and join the distinct tokens with ", ".

    Tokens keep the order in which they first appear; later duplicates are dropped.
    """
    # A dict keeps insertion order, which gives an ordered de-duplication.
    first_seen = {}
    for answer in str_list:
        for token in _split_text_to_tokens(answer):
            first_seen.setdefault(token, None)
    return ", ".join(first_seen)

def _norm(s: str) -> str:
    return re.sub(r"\s+", " ", str(s).strip().lower())

def resolve_group_json(args):
    """Locate the class-group JSON file selected by the command-line arguments.

    Returns:
        ``(path, mode)``. With ``--class_group`` set to 'coo' or 'discoo' the
        file ``<data_name>_<class_group>_cls_groups.json`` is searched for in
        ``--groups_dir``, the current working directory, this script's
        directory and ``--data_path`` (in that order) and ``mode`` is the
        class-group name. Otherwise an explicit ``--class_group_json`` is
        returned with ``mode`` None, or ``(None, None)`` when grouping is off.

    Raises:
        FileNotFoundError: a predefined grouping was requested but its file
            was not found in any candidate location.
    """
    mode = args.class_group
    if mode not in ('coo', 'discoo'):
        # No predefined grouping: honour an explicit JSON path if one was given.
        if args.class_group_json:
            return args.class_group_json, None
        return None, None

    fname = f"{args.data_name}_{mode}_cls_groups.json"
    candidates = [os.path.join(args.groups_dir, fname)] if args.groups_dir else []
    candidates.append(os.path.join(os.getcwd(), fname))
    try:
        # __file__ may be undefined (e.g. interactive use); skip that location then.
        candidates.append(os.path.join(os.path.dirname(__file__), fname))
    except Exception:
        pass
    if args.data_path:
        candidates.append(os.path.join(args.data_path, fname))

    located = next((p for p in candidates if p and os.path.exists(p)), None)
    if located is not None:
        return located, mode
    raise FileNotFoundError(f"Cannot locate '{fname}'. Checked: " + ", ".join(candidates))

def load_class_groups(path):
    """Load class groups from a JSON file and normalise their member names.

    Reads the top-level ``"groups"`` mapping (an empty mapping when the key is
    absent) and returns ``{str(group_id): [normalised class name, ...]}``.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        payload = json.load(handle)

    normalised = {}
    for group_id, members in payload.get("groups", {}).items():
        normalised[str(group_id)] = [_norm(member) for member in members]
    return normalised

def select_candidate_classes(args, cls_names):
    """Choose the candidate class names for this run and a tag naming the choice.

    Selection priority:
      1. A class-group JSON (predefined 'coo'/'discoo' or an explicit path):
         the members of group ``args.class_group_id`` that also occur in
         ``cls_names``.
      2. Class sharding (``args.class_split > 1``): the slice of ``cls_names``
         belonging to ``args.class_shard_id`` under ``args.class_scheme``.
      3. Otherwise all of ``cls_names`` with the tag ``"call"``.

    Args:
        args: Parsed command-line namespace.
        cls_names: Dataset class names.

    Returns:
        ``(selected_class_names, tag)``.

    Raises:
        ValueError: the group id is missing/empty, the group has no overlap
            with the dataset classes, or ``class_shard_id`` is outside
            ``[0, class_split - 1]``.
    """
    json_path, mode = resolve_group_json(args)
    if json_path:
        groups = load_class_groups(json_path)
        gid = str(args.class_group_id)
        grp = groups.get(gid, [])
        if not grp:
            # Sort ids numerically when possible; a non-numeric key must not
            # mask this error with an unrelated int() failure.
            try:
                available = sorted(groups, key=int)
            except ValueError:
                available = sorted(groups)
            raise ValueError(f"Group id {gid} not found in {json_path}. "
                             f"Available: {available}")
        ds_set = set(cls_names)
        selected = [c for c in grp if c in ds_set]
        if not selected:
            raise ValueError(f"Group {gid} from {json_path} has no overlap with dataset classes.")

        if mode in ('coo', 'discoo'):
            return selected, f"cgroup_{mode}{int(gid):02d}"
        return selected, f"cgroup{int(gid):02d}"

    if args.class_split and args.class_split > 1:
        # Same range check run_shard applies to --shard_id: an out-of-range id
        # would index past the class list (contiguous) or silently duplicate
        # another shard's classes (round-robin).
        if not (0 <= args.class_shard_id < args.class_split):
            raise ValueError("--class_shard_id must be in [0, class_split-1]")
        num_classes = len(cls_names)
        if args.class_scheme == 'contiguous':
            cls_idx = shard_indices_contiguous(num_classes, args.class_shard_id, args.class_split)
        else:
            cls_idx = shard_indices_roundrobin(num_classes, args.class_shard_id, args.class_split)
        selected = [cls_names[i] for i in cls_idx]
        return selected, f"cshard{args.class_shard_id:02d}_of_{args.class_split:02d}"

    return list(cls_names), "call"

def _write_shard_meta(args, out_dir, n_total):
    """Write ``shard_meta_<shard_id>_of_<num_shards>.json`` for this run into ``out_dir``."""
    meta = dict(
        scheme=args.scheme, num_shards=args.num_shards, shard_id=args.shard_id,
        N_total=n_total, sample_prompt=args.sample_prompt, batch_size=args.batch_size, seed=args.seed,
        class_split=args.class_split, class_shard_id=args.class_shard_id, class_scheme=args.class_scheme,
        class_group=args.class_group, groups_dir=args.groups_dir, class_group_json=args.class_group_json,
        class_group_id=str(args.class_group_id),
    )
    meta_name = f"shard_meta_{args.shard_id:03d}_of_{args.num_shards:03d}.json"
    with open(os.path.join(out_dir, meta_name), 'w', encoding='utf-8') as f:
        json.dump(meta, f, indent=2)


def run_shard(args):
    """Run inference over one sample shard and write its outputs.

    Writes into the directory returned by ``build_out_dir``:
      * the answers array (``answers_by_sample_*_shardXXX_of_YYY.npy``),
      * the dataset indices covered (``used_indices_shardXXX_of_YYY.npy``),
      * a JSON metadata file (``shard_meta_XXX_of_YYY.json``).

    An empty shard still writes its (empty) index file and metadata, because
    ``merge_shards`` requires a metadata file for every shard id.

    Args:
        args: Parsed command-line namespace (see ``main``).

    Raises:
        ValueError: ``num_shards`` is not positive or ``shard_id`` is out of range.
        ImportError: the requested model class is registered as ``None``.
    """
    import sys  # local import: the module header only imports os, not sys

    if args.num_shards <= 0:
        raise ValueError("--num_shards must be > 0")
    if not (0 <= args.shard_id < args.num_shards):
        raise ValueError("--shard_id must be in [0, num_shards-1]")

    if args.seed is not None:
        torch.manual_seed(args.seed)

    out_dir = build_out_dir(args)
    logger = setup_logger(out_dir, color=False, name="SAMPLE-SHARD")
    logger.info("Command: " + ' '.join(sys.argv))

    dataset = MLCDataset(args.data_name, args.data_path)
    N = len(dataset)

    if args.scheme == 'contiguous':
        idx = shard_indices_contiguous(N, args.shard_id, args.num_shards)
    else:
        idx = shard_indices_roundrobin(N, args.shard_id, args.num_shards)

    indices_path = os.path.join(
        out_dir, f"used_indices_shard{args.shard_id:03d}_of_{args.num_shards:03d}.npy")

    if not idx:
        # No inference to run, but the metadata must exist or the merge step
        # reports this shard as missing.
        logger.info(f"[warn] shard {args.shard_id}/{args.num_shards} has no samples under "
                    f"scheme={args.scheme}; writing empty shard metadata only.")
        np.save(indices_path, np.array(idx, dtype=np.int64))
        _write_shard_meta(args, out_dir, N)
        return

    subset = Subset(dataset, idx)
    loader = DataLoader(subset, batch_size=args.batch_size, shuffle=False)

    ModelClass = MODEL_CLASSES[args.model_type]
    if ModelClass is None:
        raise ImportError(f"Model '{args.model_type}' is not available in this environment.")
    try:
        model_runner = ModelClass(model_type=args.model_type, model_path=args.model_path, max_new_tokens=args.max_new_tokens)
    except TypeError:
        # Fallback for model classes whose constructor takes only the path.
        model_runner = ModelClass(args.model_path)
    try:
        model_runner.load_model_and_processor()
    except Exception as exc:
        # Still best-effort, but no longer silent. The fallback logger only
        # offers info(), so the warning goes through info().
        logger.info(f"[warn] load_model_and_processor() failed or is unavailable ({exc!r}); "
                    "continuing with the runner as constructed.")

    if args.data_name == 'objects365':
        cls_names_path = os.path.join(args.data_path, args.data_name, 'o251', f"{args.data_name}_cls_names.txt")
    else:
        cls_names_path = os.path.join(args.data_path, args.data_name, f"{args.data_name}_cls_names.txt")
    cls_names = load_cls_names(cls_names_path)

    selected_classes, tag = select_candidate_classes(args, cls_names)

    if args.sample_prompt == 'candidates':
        candidate_str = ", ".join(selected_classes)

        def make_text(n):
            return [f"What objects are in this image? Candidates: {candidate_str}. Please list only the objects, separated by commas. If there are no candidate objects, please answer NO."] * n

        if tag != "call":
            out_name = f"answers_by_sample_candidates_{tag}_shard{args.shard_id:03d}_of_{args.num_shards:03d}.npy"
        else:
            out_name = f"answers_by_sample_candidates_shard{args.shard_id:03d}_of_{args.num_shards:03d}.npy"
    else:
        def make_text(n):
            return ["What objects are in this image? Please list only the objects, separated by commas."] * n

        out_name = f"answers_by_sample_open_shard{args.shard_id:03d}_of_{args.num_shards:03d}.npy"

    answers = []
    for i, imgs in enumerate(loader):
        batch = {"image": imgs, "text": make_text(len(imgs))}
        qa = model_runner.run_batch_inference(batch)
        answers.extend(qa["answer"])
        # print_freq <= 0 disables progress logging instead of dividing by zero.
        if args.print_freq > 0 and i % args.print_freq == 0:
            logger.info(f"[{i}/{len(loader)}] shard={args.shard_id}/{args.num_shards} qa_sample={qa}")

    np.save(os.path.join(out_dir, out_name), np.asarray(answers, dtype=str))
    np.save(indices_path, np.array(idx, dtype=np.int64))
    _write_shard_meta(args, out_dir, N)
    logger.info(f"[done] wrote shard outputs to {out_dir}")

def merge_shards(args):
    """Merge the per-shard answer files of a run into dataset-ordered arrays.

    Reads every ``shard_meta_*.json`` in the run directory, checks that they
    describe one complete, consistent sharding, then writes:
      * ``answers_by_sample_open.npy`` when open-prompt shard files exist,
      * one merged array per class shard / class group, and their token-level
        fusion ``answers_by_sample_candidates.npy``,
      * ``answers_by_sample_candidates.npy`` from ungrouped candidate shard
        files when those exist (written last, so it takes precedence),
      * ``used_indices.npy`` holding ``arange(N)``.

    Samples whose shard file is absent for a given prefix get an empty string.

    Args:
        args: Parsed command-line namespace; only the fields used by
            ``build_out_dir`` are read.

    Raises:
        FileNotFoundError: no shard metadata exists in the run directory.
        ValueError: the metadata files disagree, shard ids are missing, or a
            shard file's length differs from its index count.
    """
    out_dir = build_out_dir(args)

    metas = []
    for fname in os.listdir(out_dir):
        if fname.startswith("shard_meta_") and fname.endswith(".json"):
            with open(os.path.join(out_dir, fname), 'r', encoding='utf-8') as f:
                metas.append(json.load(f))
    if not metas:
        raise FileNotFoundError(f"No shard_meta_*.json found in {out_dir}. Cannot merge.")
    metas.sort(key=lambda m: m["shard_id"])
    S = metas[0]["num_shards"]
    scheme = metas[0]["scheme"]
    N = metas[0]["N_total"]

    # Explicit raises instead of assert: these checks must survive `python -O`.
    for key, expected in (("num_shards", S), ("scheme", scheme), ("N_total", N)):
        if any(m[key] != expected for m in metas):
            raise ValueError(f"Shard metadata in {out_dir} disagree on '{key}'; cannot merge.")
    found_ids = [m["shard_id"] for m in metas]
    if found_ids != list(range(S)):
        raise ValueError(f"Missing shards: expected ids 0..{S - 1}, found {found_ids}")

    candidates_out = os.path.join(out_dir, "answers_by_sample_candidates.npy")

    def indices_for_shard(k):
        # Reuse the helpers run_shard used, so both sides agree on the partition.
        if scheme == 'contiguous':
            return shard_indices_contiguous(N, k, S)
        return shard_indices_roundrobin(N, k, S)

    def merge_pattern_to_array(prefix, out_name):
        """Scatter all ``<prefix>_shardXXX_of_YYY.npy`` files into one length-N array."""
        # Start from "" (not None): converting an unfilled object slot to str
        # would otherwise store the literal text "None" as an answer.
        full = np.full((N,), "", dtype=object)
        missing = []
        any_found = False
        for m in metas:
            k = m["shard_id"]
            idx = indices_for_shard(k)
            shard_path = os.path.join(out_dir, f"{prefix}_shard{k:03d}_of_{S:03d}.npy")
            if not os.path.exists(shard_path):
                if idx:
                    missing.append(k)
                continue
            any_found = True
            # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
            # safe if these shard files come from a trusted run — confirm.
            shard_ans = np.load(shard_path, allow_pickle=True)
            if len(shard_ans) != len(idx):
                raise ValueError(f"Shard {k} answers length {len(shard_ans)} != its index count {len(idx)} for {prefix}")
            for ds_i, answer in zip(idx, shard_ans):
                full[ds_i] = answer
        if not any_found:
            return None
        if missing:
            print(f"[warn] {prefix}: no shard file for shard ids {missing}; their samples are left empty")
        full = np.asarray(full, dtype=str)
        np.save(os.path.join(out_dir, out_name), full)
        return full

    def fuse_and_save(arrays):
        """Merge the per-sample answers of several arrays token-wise and save the result."""
        fused = np.asarray(
            [_merge_answer_strings([arr[i] for arr in arrays]) for i in range(N)],
            dtype=str,
        )
        np.save(candidates_out, fused)

    merge_pattern_to_array("answers_by_sample_open", "answers_by_sample_open.npy")

    # Class-sharded candidate runs: one merged array per class shard, then fused.
    Cs = metas[0].get("class_split", 1) or 1
    if Cs > 1:
        merged_per_cs = []
        for cs in range(Cs):
            prefix = f"answers_by_sample_candidates_cshard{cs:02d}_of_{Cs:02d}"
            arr = merge_pattern_to_array(prefix, f"{prefix}.npy")
            if arr is not None:
                merged_per_cs.append(arr)
        if merged_per_cs:
            fuse_and_save(merged_per_cs)

    # Class-group candidate runs: discover which group prefixes have shard files.
    present_prefixes = set()
    for fname in os.listdir(out_dir):
        m1 = re.match(r"answers_by_sample_candidates_cgroup_(coo|discoo)(\d+)_shard\d+_of_\d+\.npy$", fname)
        m2 = re.match(r"answers_by_sample_candidates_cgroup(\d+)_shard\d+_of_\d+\.npy$", fname)
        if m1:
            mode, gid = m1.group(1), int(m1.group(2))
            present_prefixes.add(f"answers_by_sample_candidates_cgroup_{mode}{gid:02d}")
        elif m2:
            gid = int(m2.group(1))
            present_prefixes.add(f"answers_by_sample_candidates_cgroup{gid:02d}")

    if present_prefixes:
        merged_arrays = []
        for prefix in sorted(present_prefixes):
            arr = merge_pattern_to_array(prefix, f"{prefix}.npy")
            if arr is not None:
                merged_arrays.append(arr)
        if merged_arrays:
            fuse_and_save(merged_arrays)

    # Ungrouped candidate shards, if any, overwrite the fused file.
    merge_pattern_to_array("answers_by_sample_candidates", "answers_by_sample_candidates.npy")

    np.save(os.path.join(out_dir, "used_indices.npy"), np.arange(N, dtype=np.int64))
    print(f"[done] merged into {out_dir}")

def main():
    """Parse the command line and dispatch to ``run_shard`` or ``merge_shards``."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--action', choices=['run_shard','merge'], default='run_shard')

    # Model and data selection.
    ap.add_argument('--model_type', choices=list(MODEL_CLASSES.keys()), required=True)
    ap.add_argument('--model_path', required=True)
    ap.add_argument('--data_name', default='coco2014')
    ap.add_argument('--data_type', default='mlc')
    ap.add_argument('--data_path', required=True)
    ap.add_argument('--output', default='./output')

    # Inference settings.
    ap.add_argument('--batch_size', type=int, default=16)
    ap.add_argument('--sample_prompt', choices=['open','candidates'], default='open')
    ap.add_argument('--seed', type=int, default=1)
    ap.add_argument('--print_freq', type=int, default=1)
    ap.add_argument('--max_new_tokens', type=int, default=64)

    # Sample sharding.
    ap.add_argument('--scheme', choices=['contiguous','roundrobin'], default='contiguous')
    ap.add_argument('--num_shards', type=int, default=1)
    ap.add_argument('--shard_id', type=int, default=0)

    # Class sharding.
    ap.add_argument('--class_split', type=int, default=1)
    ap.add_argument('--class_shard_id', type=int, default=0)
    ap.add_argument('--class_scheme', choices=['contiguous','roundrobin'], default='contiguous')

    # Class grouping.
    ap.add_argument('--class_group', choices=['coo','discoo'], default=None,
                    help="Choose a predefined class grouping JSON: 'coo' or 'discoo'. If omitted, grouping is disabled.")
    # The help text names the files exactly as resolve_group_json looks them up.
    ap.add_argument('--groups_dir', default=None,
                    help="Directory containing <data_name>_coo_cls_groups.json and "
                         "<data_name>_discoo_cls_groups.json. It is searched first; CWD, "
                         "the script dir, and --data_path are tried afterwards.")

    ap.add_argument('--class_group_json', default=None, help=argparse.SUPPRESS)
    ap.add_argument('--class_group_id', type=int, default=1, help='Which group id to run (when grouping is enabled)')

    args = ap.parse_args()

    if args.action == 'run_shard':
        run_shard(args)
    else:
        merge_shards(args)

if __name__ == '__main__':
    main()
