"""
Dataset classes for loading and managing dense retrieval training data.
"""

import os
import logging
from typing import Dict, List, Tuple, Optional, Set
import random
import numpy as np
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)


def load_collection(collection_path: str) -> Tuple[List[str], Dict[str, str]]:
    """
    Read a document collection stored as a two-column TSV file.

    Every non-blank line is expected to look like ``doc_id \\t doc_text``.
    Lines are stripped of surrounding whitespace before parsing; a line that
    does not split into exactly two tab-separated fields is skipped and a
    warning is logged with its (0-based) line index.

    Args:
        collection_path: Path to collection.tsv file

    Returns:
        Tuple of (doc_ids list, doc_id_to_text dict). The list holds the IDs
        in file order (a repeated ID appears once per occurrence); the dict
        keeps the text of the last occurrence of each ID.
    """
    logger.info(f"Loading collection from {collection_path}")
    ordered_ids: List[str] = []
    text_by_id: Dict[str, str] = {}

    with open(collection_path, "r", encoding="utf-8") as handle:
        for row_number, raw_row in enumerate(handle):
            row = raw_row.strip()
            if row:
                fields = row.split("\t")
                if len(fields) == 2:
                    # Well-formed row: record the ID and its text.
                    ordered_ids.append(fields[0])
                    text_by_id[fields[0]] = fields[1]
                else:
                    logger.warning(
                        f"Skipping malformed line {row_number}: {row[:100]}"
                    )

    logger.info(f"Loaded {len(ordered_ids)} documents")
    return ordered_ids, text_by_id


def load_queries(queries_path: str) -> Tuple[List[str], Dict[str, str]]:
    """
    Read queries stored as a two-column TSV file.

    Every non-blank line is expected to look like ``query_id \\t query_text``.
    Surrounding whitespace is stripped first; a line that does not contain
    exactly one tab (i.e. exactly two fields) is skipped and a warning is
    logged with its (0-based) line index.

    Args:
        queries_path: Path to queries.tsv file

    Returns:
        Tuple of (query_ids list, query_id_to_text dict). The list holds the
        IDs in file order; the dict keeps the text of the last occurrence of
        each ID.
    """
    logger.info(f"Loading queries from {queries_path}")
    seen_ids: List[str] = []
    text_lookup: Dict[str, str] = {}

    with open(queries_path, "r", encoding="utf-8") as stream:
        for position, raw in enumerate(stream):
            entry = raw.strip()
            if not entry:
                continue

            # Exactly one tab means exactly two fields.
            if entry.count("\t") != 1:
                logger.warning(f"Skipping malformed line {position}: {entry[:100]}")
                continue

            identifier, _, text = entry.partition("\t")
            seen_ids.append(identifier)
            text_lookup[identifier] = text

    logger.info(f"Loaded {len(seen_ids)} queries")
    return seen_ids, text_lookup


def load_qrels(qrels_path: str) -> Dict[str, Dict[str, int]]:
    """
    Load qrels from TSV file.

    Format: query_id \\t 0 \\t doc_id \\t relevance_score

    Blank lines are ignored. A line that does not have exactly four
    tab-separated fields, or whose relevance field is not an integer (for
    example a header row), is skipped with a warning instead of aborting the
    whole load. If the same (query_id, doc_id) pair appears more than once,
    the last occurrence wins.

    Args:
        qrels_path: Path to qrels.tsv file

    Returns:
        Dict mapping query_id -> {doc_id: relevance_score}
        Relevance score: 0 = negative, >0 = positive
    """
    logger.info(f"Loading qrels from {qrels_path}")
    qrels: Dict[str, Dict[str, int]] = {}

    with open(qrels_path, "r", encoding="utf-8") as f:
        for line_idx, line in enumerate(f):
            line = line.strip()
            if not line:
                continue

            parts = line.split("\t")
            if len(parts) != 4:
                logger.warning(
                    f"Skipping malformed qrels line {line_idx}: {line[:100]}"
                )
                continue

            # The second column is unused.
            query_id, _, doc_id, relevance = parts
            try:
                score = int(relevance)
            except ValueError:
                # Treat a non-integer relevance like any other malformed row
                # rather than letting one bad line crash the load.
                logger.warning(
                    f"Skipping malformed qrels line {line_idx}: {line[:100]}"
                )
                continue

            # Only create the per-query dict once the row is known to be valid.
            qrels.setdefault(query_id, {})[doc_id] = score

    logger.info(f"Loaded qrels for {len(qrels)} queries")
    return qrels


class DenseRetrievalDataset(Dataset):
    """
    Dataset for dense retrieval training.

    Only queries that have a text entry, a qrels entry, and at least one
    positive (relevance > 0) document are kept; all others are dropped during
    validation.

    Each item contains:
    - query_id
    - query_text
    - positive_doc_ids (list)
    - positive_doc_texts (list)
    - mined_negative_doc_ids (list)
    - mined_negative_doc_texts (list)
    - sampled_negative_doc_ids (list, optional - set by negative sampler)
    - sampled_negative_doc_texts (list, optional - set by negative sampler)
    """

    def __init__(
        self,
        query_ids: List[str],
        query_id_to_text: Dict[str, str],
        doc_id_to_text: Dict[str, str],
        qrels: Dict[str, Dict[str, int]],
        max_positives: int = 1,
        max_mined_negatives: int = 1,
        query_order: Optional[List[str]] = None,
        sampled_negatives: Optional[Dict[str, List[str]]] = None,
        seed: int = 42,
    ):
        """
        Initialize the dataset.

        Args:
            query_ids: List of all query IDs
            query_id_to_text: Mapping from query ID to query text
            doc_id_to_text: Mapping from doc ID to doc text
            qrels: Qrels dict mapping query_id -> {doc_id: relevance}
            max_positives: Maximum number of positives per query
            max_mined_negatives: Maximum number of mined negatives per query
            query_order: Optional ordering of queries (from batch sampler);
                when given it is used instead of ``query_ids``
            sampled_negatives: Optional sampled negatives per query (from negative sampler)
            seed: Random seed for sampling
        """
        self.query_id_to_text = query_id_to_text
        self.doc_id_to_text = doc_id_to_text
        self.qrels = qrels
        self.max_positives = max_positives
        self.max_mined_negatives = max_mined_negatives
        self.seed = seed

        # Use provided query order or default to input order
        self.query_ids = query_order if query_order is not None else query_ids

        # Store sampled negatives (empty mapping when none were provided)
        self.sampled_negatives = (
            sampled_negatives if sampled_negatives is not None else {}
        )

        # Drop queries that cannot produce a training example
        self._validate_queries()

        logger.info(f"Initialized dataset with {len(self.query_ids)} queries")
        logger.info(f"Max positives per query: {max_positives}")
        logger.info(f"Max mined negatives per query: {max_mined_negatives}")
        if sampled_negatives:
            logger.info(
                f"Sampled negatives provided for {len(sampled_negatives)} queries"
            )

    def _validate_queries(self):
        """
        Filter ``self.query_ids`` down to usable queries, preserving order.

        A query is kept only if it has a text entry, a qrels entry, and at
        least one document with relevance > 0. Missing text is logged as a
        warning; the other two cases are dropped without a per-query message.
        """
        valid_query_ids = []
        for qid in self.query_ids:
            if qid not in self.query_id_to_text:
                logger.warning(f"Query {qid} not found in query texts, skipping")
                continue
            if qid not in self.qrels:
                # Skipped silently: no per-query warning is logged here.
                continue

            # Check that query has at least one positive document
            query_qrels = self.qrels[qid]
            has_positive = any(rel > 0 for rel in query_qrels.values())
            if not has_positive:
                # Skipped silently: no per-query warning is logged here.
                continue

            valid_query_ids.append(qid)

        self.query_ids = valid_query_ids
        logger.info(f"Validated {len(self.query_ids)} queries")

    def _sample_items(self, items: List, max_count: int, query_id: str) -> List:
        """
        Sample max_count items from a list. If list has fewer items, return all.
        Uses deterministic sampling based on query_id and seed.

        The random generator is seeded with a string built from the seed and
        the query ID rather than with ``hash(...)``: Python randomises string
        hashes per process, so a hash-based seed would pick different items
        in different runs or worker processes.

        Args:
            items: List of items to sample from
            max_count: Maximum number of items to sample
            query_id: Query ID for deterministic sampling

        Returns:
            Sampled list of items (the input list itself when it already has
            at most max_count items)
        """
        if len(items) <= max_count:
            return items

        # A str seed is turned into an integer by the random module itself,
        # independently of PYTHONHASHSEED, so the draw is reproducible.
        rng = random.Random(f"{self.seed}:{query_id}")
        return rng.sample(items, max_count)

    def __len__(self) -> int:
        """Return the number of validated queries."""
        return len(self.query_ids)

    def __getitem__(self, idx: int) -> Dict:
        """
        Get a training example.

        Documents listed in qrels with relevance > 0 are positives and those
        with relevance == 0 are mined negatives; each group is capped via
        ``_sample_items``. A doc ID missing from ``doc_id_to_text`` yields an
        empty string as its text.

        Args:
            idx: Position of the query in the current query order

        Returns:
            Dict with keys:
                - query_id: str
                - query_text: str
                - positive_doc_ids: List[str]
                - positive_doc_texts: List[str]
                - mined_negative_doc_ids: List[str]
                - mined_negative_doc_texts: List[str]
                - sampled_negative_doc_ids: List[str]
                - sampled_negative_doc_texts: List[str]
        """
        query_id = self.query_ids[idx]
        query_text = self.query_id_to_text[query_id]

        # Get all docs for this query from qrels
        query_qrels = self.qrels[query_id]

        # Separate positives (relevance > 0) and negatives (relevance == 0)
        positive_doc_ids = [doc_id for doc_id, rel in query_qrels.items() if rel > 0]
        mined_negative_doc_ids = [
            doc_id for doc_id, rel in query_qrels.items() if rel == 0
        ]

        # Sample if we have more than max
        positive_doc_ids = self._sample_items(
            positive_doc_ids, self.max_positives, query_id
        )
        mined_negative_doc_ids = self._sample_items(
            mined_negative_doc_ids, self.max_mined_negatives, query_id
        )

        # Get texts for positives and mined negatives
        positive_doc_texts = [
            self.doc_id_to_text.get(doc_id, "") for doc_id in positive_doc_ids
        ]
        mined_negative_doc_texts = [
            self.doc_id_to_text.get(doc_id, "") for doc_id in mined_negative_doc_ids
        ]

        # Get sampled negatives if available (empty list otherwise)
        sampled_negative_doc_ids = self.sampled_negatives.get(query_id, [])
        sampled_negative_doc_texts = [
            self.doc_id_to_text.get(doc_id, "") for doc_id in sampled_negative_doc_ids
        ]

        return {
            "query_id": query_id,
            "query_text": query_text,
            "positive_doc_ids": positive_doc_ids,
            "positive_doc_texts": positive_doc_texts,
            "mined_negative_doc_ids": mined_negative_doc_ids,
            "mined_negative_doc_texts": mined_negative_doc_texts,
            "sampled_negative_doc_ids": sampled_negative_doc_ids,
            "sampled_negative_doc_texts": sampled_negative_doc_texts,
        }

    def update_sampled_negatives(self, sampled_negatives: Dict[str, List[str]]):
        """
        Update sampled negatives for the dataset.

        The previous mapping is replaced entirely, not merged.

        Args:
            sampled_negatives: Dict mapping query_id -> list of negative doc_ids
        """
        self.sampled_negatives = sampled_negatives
        logger.info(f"Updated sampled negatives for {len(sampled_negatives)} queries")

    def update_query_order(self, query_order: List[str]):
        """
        Update the order of queries in the dataset.

        The new order is re-validated, so IDs without text, qrels, or a
        positive document are dropped from it.

        Args:
            query_order: New ordering of query IDs
        """
        self.query_ids = query_order
        self._validate_queries()
        logger.info(
            f"Updated query order, dataset now has {len(self.query_ids)} queries"
        )
