import argparse
import json
from pathlib import Path
from typing import List, Dict, Any

import torch
from datasets import Dataset, concatenate_datasets
from transformers import (
    BertTokenizerFast,
    BertForTokenClassification,

)

class BERT_Lifter:
    """Load a fine-tuned token-classification model and lift propositions from text.

    Label id 0 is treated as "outside any proposition"; every other predicted
    label id ``n`` is reported under the key ``"prop_n"``.
    """

    def __init__(self, model_dir: str):
        """Load the tokenizer and model from *model_dir* and switch to eval mode."""
        self.tokenizer = BertTokenizerFast.from_pretrained(model_dir)
        self.model = BertForTokenClassification.from_pretrained(model_dir)
        self.model.eval()
        # Names for the non-outside labels; id 0 is deliberately left out.
        self.id2label = {i: f"prop_{i}" for i in range(1, self.model.num_labels)}

    @torch.no_grad()
    def lift(self, text: str) -> Dict[str, List[str]]:
        """Group the words of *text* by their predicted proposition label.

        The text is split on whitespace, tokenized word by word, and run
        through the model. Each word takes the label of its first sub-token
        whose prediction is not 0; words whose sub-tokens are all predicted
        as 0 are omitted.

        Input longer than the tokenizer's maximum length is truncated, so
        words past that limit are not labelled.

        Returns:
            A dict mapping ``"prop_<id>"`` to the words assigned to that
            label, in their original order, with keys sorted by label id.
            Empty or whitespace-only text yields an empty dict.
        """
        # Basic whitespace split to words → better alignment with training.
        words = text.strip().split()
        if not words:
            # Nothing to label; skip the tokenizer and the forward pass.
            return {}

        enc = self.tokenizer(
            words,
            is_split_into_words=True,
            return_tensors="pt",
            # Without truncation an over-long input produces more sub-tokens
            # than the model has position embeddings for.
            truncation=True,
        )
        logits = self.model(**enc).logits  # (1, seq_len, num_labels)
        preds = logits.argmax(-1)[0].tolist()  # one label id per sub-token

        # Map word index -> label id. Word indices arrive in ascending order,
        # so insertion order of this dict is the word order of the text.
        token_preds: Dict[int, int] = {}
        for sub_id, word_idx in enumerate(enc.word_ids()):
            if word_idx is None:
                continue  # special token, not part of any word
            label_id = preds[sub_id]
            if label_id == 0:
                continue  # outside any proposition
            # Keep the first non-zero label seen for this word.
            token_preds.setdefault(word_idx, label_id)

        # Group words by predicted label id
        props: Dict[int, List[str]] = {}
        for idx, label_id in token_preds.items():
            props.setdefault(label_id, []).append(words[idx])

        # Convert numeric keys to prop_n strings, ordered by id
        return {f"prop_{i}": toks for i, toks in sorted(props.items())}
