"""
Utility functions for data management.
"""

import os
import logging
from typing import Dict, List

# Module-level logger named after this module; used by the qrels save/load helpers below.
logger = logging.getLogger(__name__)


def save_training_qrels(
    output_path: str,
    query_ids: List[str],
    query_id_to_positives: Dict[str, List[str]],
    query_id_to_mined_negatives: Dict[str, List[str]],
    query_id_to_sampled_negatives: Dict[str, List[str]],
):
    """
    Save training qrels to a TSV file, one line per query.

    Format of each line (four tab-separated columns; doc IDs within a column
    are comma-separated, and a column is left empty when the query has no
    doc IDs of that kind):
        query_id \t positive_ids \t mined_negative_ids \t sampled_negative_ids

    The parent directory of ``output_path`` is created if it does not exist.
    An existing file at ``output_path`` is overwritten.

    NOTE: IDs are written verbatim, without escaping. An ID containing a tab,
    comma, or newline would produce a line that cannot be parsed back
    unambiguously.

    Args:
        output_path: Path to save the qrels file. May be a bare filename, in
            which case the file is written to the current working directory.
        query_ids: Ordered list of query IDs; one line is written per entry,
            in this order.
        query_id_to_positives: Mapping from query_id to list of positive doc_ids
        query_id_to_mined_negatives: Mapping from query_id to list of mined negative doc_ids
        query_id_to_sampled_negatives: Mapping from query_id to list of sampled negative doc_ids

    Query IDs missing from a mapping are written with an empty column.
    """
    # os.path.dirname() returns "" for a bare filename and os.makedirs("")
    # raises FileNotFoundError, so only create the directory when there is one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        for query_id in query_ids:
            # `or []` covers both a missing key and a falsy value (e.g. None);
            # joining an empty list yields the empty column.
            positives_str = ",".join(query_id_to_positives.get(query_id) or [])
            mined_negs_str = ",".join(
                query_id_to_mined_negatives.get(query_id) or []
            )
            sampled_negs_str = ",".join(
                query_id_to_sampled_negatives.get(query_id) or []
            )

            # Write as tab-separated
            f.write(
                f"{query_id}\t{positives_str}\t{mined_negs_str}\t{sampled_negs_str}\n"
            )

    logger.info(
        "Saved training qrels for %d queries to %s", len(query_ids), output_path
    )


def load_training_qrels(qrels_path: str) -> tuple:
    """
    Load training qrels from a previously saved file.

    Each non-blank line must have exactly four tab-separated columns:
        query_id \t positive_ids \t mined_negative_ids \t sampled_negative_ids
    where each ID column is a comma-separated list and may be empty. Lines
    with a different number of columns are logged as warnings and skipped.

    Args:
        qrels_path: Path to the training qrels file

    Returns:
        Tuple of (query_ids, query_id_to_positives, query_id_to_mined_negatives, query_id_to_sampled_negatives).
        ``query_ids`` lists the query IDs in file order. Each mapping goes
        from query_id to a list of doc IDs, with an empty list for an empty
        column. If a query_id appears on several lines, it is listed once per
        line in ``query_ids`` and the mappings keep the values from its last
        line.
    """
    query_ids = []
    query_id_to_positives = {}
    query_id_to_mined_negatives = {}
    query_id_to_sampled_negatives = {}
    # Same order as the ID columns in the file.
    column_mappings = (
        query_id_to_positives,
        query_id_to_mined_negatives,
        query_id_to_sampled_negatives,
    )

    with open(qrels_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            # Strip only the line terminator. A plain strip() would also eat
            # trailing tabs, which are how empty trailing columns are encoded
            # (e.g. "q1\tp1\tn1\t"), and such rows would be rejected below.
            line = raw_line.rstrip("\r\n")
            if not line.strip():
                continue

            parts = line.split("\t")
            if len(parts) != 4:
                logger.warning("Skipping malformed line: %s", line[:100])
                continue

            query_id, *id_columns = parts
            query_ids.append(query_id)

            # Parse doc IDs; "".split(",") would give [""], so map an empty
            # column to an empty list explicitly.
            for mapping, column in zip(column_mappings, id_columns):
                mapping[query_id] = column.split(",") if column else []

    logger.info(
        "Loaded training qrels for %d queries from %s", len(query_ids), qrels_path
    )
    return (
        query_ids,
        query_id_to_positives,
        query_id_to_mined_negatives,
        query_id_to_sampled_negatives,
    )
