# merged_training/dataset.py

import random
from typing import List, Tuple

from datasets import load_dataset, load_from_disk
from torch.utils.data import Dataset
import os

from tevatron.retriever.arguments import DataArguments # Using existing args for simplicity

import logging
logger = logging.getLogger(__name__)

class JointTrainDataset(Dataset):
    """
    Dataset for joint training. It provides raw data in the format of
    (query_text, [passage_text_1, passage_text_2...]), where the first
    passage is always the positive one. Prefixes are handled by the collator.

    Each underlying example is read as a mapping with the keys ``'query'``,
    ``'positive_passages'`` and ``'negative_passages'``; every passage is a
    mapping with a ``'text'`` entry and an optional ``'title'`` entry.
    """

    # Seed used for the deterministic sampling when no trainer (or no seed)
    # is available, so the dataset can be indexed stand-alone.
    _FALLBACK_SEED = 0

    def __init__(self, data_args: "DataArguments", trainer=None):
        """Load (and optionally shard) the training data.

        Args:
            data_args: Configuration object. This class reads
                ``dataset_name``, ``dataset_config``, ``dataset_path``,
                ``dataset_split``, ``dataset_cache_dir``,
                ``dataset_number_of_shards``, ``dataset_shard_index``,
                ``train_group_size``, ``positive_passage_no_shuffle`` and
                ``negative_passage_no_shuffle`` from it.
            trainer: Optional trainer object. When given, its
                ``state.epoch`` and ``args.seed`` drive the per-epoch
                passage selection in ``__getitem__``.
        """
        self.data_args = data_args
        self.trainer = trainer

        if os.path.isdir(self.data_args.dataset_name):
            # A local directory is treated as a dataset saved to disk.
            logger.info(f"Loading dataset from disk: {self.data_args.dataset_name}")
            self.train_data = load_from_disk(self.data_args.dataset_name)
        else:
            logger.info(f"Loading dataset from Hugging Face Hub: {self.data_args.dataset_name}")
            # NOTE(security): trust_remote_code=True lets a dataset's loading
            # script execute locally; only point dataset_name at trusted
            # repositories.
            self.train_data = load_dataset(
                self.data_args.dataset_name,
                self.data_args.dataset_config,
                data_files=self.data_args.dataset_path,
                split=self.data_args.dataset_split,
                cache_dir=self.data_args.dataset_cache_dir,
                trust_remote_code=True
            )

        if self.data_args.dataset_number_of_shards > 1:
            self.train_data = self.train_data.shard(
                num_shards=self.data_args.dataset_number_of_shards,
                index=self.data_args.dataset_shard_index,
            )

        logger.info(f"Loaded {len(self.train_data)} training examples.")

    def __len__(self):
        """Return the number of training examples (after sharding)."""
        return len(self.train_data)

    def __getitem__(self, item) -> Tuple[str, List[str]]:
        """Return ``(query_text, passages)`` for the example at ``item``.

        ``passages`` holds one positive passage followed by
        ``train_group_size - 1`` negative passages, each rendered as
        ``"<title> <text>"`` with surrounding whitespace stripped.

        Raises:
            IndexError: If the example has no negative passages while
                ``train_group_size > 1`` (raised by ``random.choices``), or
                if it has no positive passages and
                ``positive_passage_no_shuffle`` is set.
        """
        group = self.train_data[item]
        epoch, seed = self._epoch_and_seed()
        hashed_seed = hash(item + seed)

        positive = self._select_positive(group['positive_passages'], hashed_seed, epoch)
        negatives = self._select_negatives(group['negative_passages'], hashed_seed, epoch)

        passages = [self._format_passage(positive)]
        passages.extend(self._format_passage(neg_psg) for neg_psg in negatives)
        return group['query'], passages

    def _epoch_and_seed(self) -> Tuple[int, int]:
        """Return the current integer epoch and sampling seed.

        Falls back to epoch 0 and ``_FALLBACK_SEED`` when no trainer is
        attached, or when the trainer has not populated the value yet.
        """
        trainer = self.trainer
        if trainer is None:
            return 0, self._FALLBACK_SEED

        raw_epoch = getattr(getattr(trainer, 'state', None), 'epoch', None)
        epoch = int(raw_epoch) if raw_epoch is not None else 0

        seed = getattr(getattr(trainer, 'args', None), 'seed', None)
        if seed is None:
            seed = self._FALLBACK_SEED
        return epoch, seed

    def _select_positive(self, positives, hashed_seed: int, epoch: int):
        """Pick one positive passage, rotating with the epoch unless disabled."""
        if self.data_args.positive_passage_no_shuffle or len(positives) == 1:
            return positives[0]
        return positives[(hashed_seed + epoch) % len(positives)]

    def _select_negatives(self, negatives, hashed_seed: int, epoch: int) -> list:
        """Pick ``train_group_size - 1`` negative passages."""
        negative_size = self.data_args.train_group_size - 1
        if negative_size <= 0:
            return []
        if len(negatives) < negative_size:
            # Not enough negatives: sample with replacement to fill the group.
            return random.choices(negatives, k=negative_size)
        if self.data_args.negative_passage_no_shuffle:
            return negatives[:negative_size]

        # Deterministic per-example shuffle; the window slides by
        # negative_size each epoch and wraps around via the doubled list.
        offset = (epoch * negative_size) % len(negatives)
        shuffled = random.Random(hashed_seed).sample(negatives, len(negatives))
        return (shuffled * 2)[offset: offset + negative_size]

    @staticmethod
    def _format_passage(passage) -> str:
        """Join a passage's title and text into a single stripped string."""
        # `or ''` also covers a title that is present but None.
        title = (passage.get('title') or '').strip()
        return f"{title} {passage['text'].strip()}".strip()