import json
import torch
from torch.utils.data import Dataset
import random
from sentence_transformers import SentenceTransformer

# Seed Python's global RNG at import time so the random.sample / random.shuffle
# calls in TextPairDataset produce the same subset and order across runs.
# NOTE(review): this reseeds the process-wide `random` state as an import side
# effect, and any other code that consumes `random` between this import and the
# dataset's construction changes which samples are drawn.
random.seed(42)

class TextPairDataset(Dataset):
    """Labelled (token ids, reason) pairs built from a JSON file.

    Each JSON record must have a ``"reason"`` entry and may have:

    * ``"pos_token_texts_list"``: a list of token-string lists; every
      non-empty one becomes a positive sample (label ``1``).
    * ``"negatives"``: a mapping from negative-type name to a token-string
      list; every non-empty list whose key is in ``negative_types`` becomes
      a negative sample (label ``0``).

    Token strings are converted to ids with the tokenizer of the given
    SentenceTransformer model. Each sample is stored in ``self.samples`` as a
    dict with keys ``"token_ids"``, ``"reason"`` and ``"label"``.
    """

    # Negative-type keys used when the caller does not pass ``negative_types``.
    DEFAULT_NEGATIVE_TYPES = (
        "neg_type_1_tokens",
        "neg_type_2_tokens",
        "neg_type_3_tokens",
        "neg_type_4_tokens",
    )

    def __init__(self, json_file, model_name, limit=0, negative_types=None,
                 seed=None):
        """Load ``json_file`` and build the shuffled, labelled sample list.

        Args:
            json_file: Path to a UTF-8 JSON file containing a list of records.
            model_name: Name or path passed to ``SentenceTransformer``; its
                tokenizer converts token strings to ids.
            limit: If positive, keep a random subset of at most ``limit``
                samples. ``0``, a negative value or ``None`` keeps everything.
            negative_types: Keys under each record's ``"negatives"`` mapping
                to turn into negative samples. Defaults to
                ``DEFAULT_NEGATIVE_TYPES``.
            seed: If given, subsampling and shuffling use a private
                ``random.Random(seed)``. If ``None`` (the default), the
                module-level ``random`` state is used.

        Raises:
            OSError: If ``json_file`` cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If a record has no ``"reason"`` key.
        """
        # NOTE(review): only the tokenizer is used in this class; the full
        # model is kept as ``self.model`` in case callers rely on it. Because
        # it lives on the dataset object it may be duplicated into DataLoader
        # worker processes -- consider loading just the tokenizer if unused.
        self.model = SentenceTransformer(model_name)
        self.tokenizer = self.model.tokenizer

        with open(json_file, "r", encoding="utf-8") as f:
            raw_data = json.load(f)

        if negative_types is None:
            negative_types = self.DEFAULT_NEGATIVE_TYPES

        self.samples = self._build_samples(raw_data, negative_types)

        # A private generator isolates the draw from other users of the
        # global ``random`` state; without a seed, fall back to that state.
        rng = random if seed is None else random.Random(seed)
        if limit and limit > 0:
            self.samples = rng.sample(
                self.samples, min(limit, len(self.samples))
            )
        rng.shuffle(self.samples)

    def _build_samples(self, raw_data, negative_types):
        """Flatten raw JSON records into labelled sample dicts.

        For each record, positives are emitted before negatives, and
        negatives follow the order of ``negative_types``. Empty token lists
        are skipped for both labels.
        """
        samples = []
        for item in raw_data:
            reason = item["reason"]

            # ``or`` also covers keys that are present but ``null`` in JSON.
            for pos_tokens in item.get("pos_token_texts_list") or []:
                if pos_tokens:
                    samples.append(self._make_sample(pos_tokens, reason, 1))

            negatives = item.get("negatives") or {}
            for neg_type in negative_types:
                neg_tokens = negatives.get(neg_type)
                if neg_tokens:
                    samples.append(self._make_sample(neg_tokens, reason, 0))
        return samples

    def _make_sample(self, tokens, reason, label):
        """Return one sample dict, converting ``tokens`` to ids."""
        return {
            "token_ids": self.tokenizer.convert_tokens_to_ids(tokens),
            "reason": reason,
            "label": label,
        }

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(token_ids, reason, label)`` for sample ``idx``."""
        sample = self.samples[idx]
        return sample["token_ids"], sample["reason"], sample["label"]
