import argparse
import json
from pathlib import Path
from typing import List, Dict, Any

import torch
from datasets import Dataset, concatenate_datasets
from transformers import (
    BertTokenizerFast,
    BertForTokenClassification,

)
class BERT_Grounder:
    """Predict `prefix_target` for a (prefix, sentence) pair."""

    def __init__(self, model_dir: str | Path):
        self.tokenizer = BertTokenizerFast.from_pretrained(model_dir)
        self.model = BertForTokenClassification.from_pretrained(model_dir)
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    @torch.no_grad()
    def ground(self, sentence, prefix) -> List[int]:
        words = sentence if isinstance(sentence, list) else sentence.strip().split()
        combined = prefix + words
        enc = self.tokenizer(combined, is_split_into_words=True, return_tensors="pt").to(self.device)
        preds = self.model(**enc).logits.argmax(-1)[0].tolist()
        word_ids = enc.word_ids()

        max_p = len(prefix) - 1
        prefix_preds = [0] * len(prefix)
        for sub_idx, w_idx in enumerate(word_ids):
            if w_idx is None or w_idx > max_p:
                continue
            if preds[sub_idx] == 1:  # positive
                prefix_preds[w_idx] = 1
        return prefix_preds