"""
Universal data loader for contextual bandit experiments.

This module provides a flexible data loader that can handle different datasets
with standardized embedding columns and configurable query-arm mapping logic.
"""

import hashlib
import json
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import torch
from datasets import load_dataset, load_from_disk
from tqdm import tqdm


class QueryArmMapper(ABC):
    """Strategy interface that resolves a query row to its correct arms.

    Each dataset stores its relevance labels differently; a concrete
    subclass encapsulates how to read those labels from one query row and
    translate them into arm (tool) indices.
    """

    @abstractmethod
    def extract_correct_arms(
        self, query: Dict[str, Any], tool_id_to_index: Dict[str, int]
    ) -> List[int]:
        """
        Return the indices of the arms that are correct for ``query``.

        Args:
            query: One query row as a dictionary.
            tool_id_to_index: Lookup table from tool ID to arm index.

        Returns:
            The arm indices considered correct for this query.
        """
        ...


class ToolRetQueryArmMapper(QueryArmMapper):
    """Query-arm mapper for the ToolRet dataset."""

    def extract_correct_arms(
        self, query: Dict[str, Any], tool_id_to_index: Dict[str, int]
    ) -> List[int]:
        """Resolve correct arms from the JSON-encoded ``labels`` field.

        ``labels`` is decoded with ``json.loads``; every decoded entry that
        is a dict carrying an ``id`` known to ``tool_id_to_index``
        contributes one arm index.

        Args:
            query: Query row; only the ``labels`` key is read.
            tool_id_to_index: Lookup table from tool ID to arm index.

        Returns:
            Arm indices in label order. Empty when ``labels`` is missing,
            empty, or cannot be decoded.
        """
        matched_arms = []
        if "labels" not in query or not query["labels"]:
            return matched_arms

        try:
            for entry in json.loads(query["labels"]):
                if not isinstance(entry, dict) or "id" not in entry:
                    continue
                label_id = entry["id"]
                if label_id in tool_id_to_index:
                    matched_arms.append(tool_id_to_index[label_id])
        except (json.JSONDecodeError, TypeError):
            # Best effort: malformed labels are ignored, and any arms that
            # were collected before the failure are still returned.
            pass
        return matched_arms


class UltraToolQueryArmMapper(QueryArmMapper):
    """Query-arm mapper for the UltraTool dataset."""

    def extract_correct_arms(
        self, query: Dict[str, Any], tool_id_to_index: Dict[str, int]
    ) -> List[int]:
        """Resolve the single correct arm from the query's ``tool`` field.

        Args:
            query: Query row; only the ``tool`` key is read.
            tool_id_to_index: Lookup table from tool ID to arm index.

        Returns:
            A one-element list with the arm index, or an empty list when
            ``tool`` is missing, falsy, or not a known tool ID.
        """
        oracle = query.get("tool", None)
        if not oracle or oracle not in tool_id_to_index:
            return []
        return [tool_id_to_index[oracle]]


class NfCorpusQueryArmMapper(QueryArmMapper):
    """Query-arm mapper for the NfCorpus dataset."""

    def extract_correct_arms(
        self, query: Dict[str, Any], tool_id_to_index: Dict[str, int]
    ) -> List[int]:
        """Resolve the correct arm from the query's ``_id`` field.

        Args:
            query: Query row; only the ``_id`` key is read.
            tool_id_to_index: Lookup table from tool ID to arm index.

        Returns:
            A one-element list with the arm index, or an empty list when
            ``_id`` is missing, falsy, or not a known tool ID.
        """
        document_id = query.get("_id")
        is_known = bool(document_id) and document_id in tool_id_to_index
        return [tool_id_to_index[document_id]] if is_known else []


class ArguanaQueryArmMapper(QueryArmMapper):
    """Query-arm mapper for the Arguana dataset."""

    def extract_correct_arms(
        self, query: Dict[str, Any], tool_id_to_index: Dict[str, int]
    ) -> List[int]:
        """Resolve the correct arm from the query's ``corpus-id`` field.

        Args:
            query: Query row; only the ``corpus-id`` key is read.
            tool_id_to_index: Lookup table from tool ID to arm index.

        Returns:
            A one-element list with the arm index, or an empty list when
            ``corpus-id`` is missing, falsy, or not a known tool ID.
        """
        corpus_id = query.get("corpus-id")
        if not corpus_id:
            return []
        if corpus_id not in tool_id_to_index:
            return []
        return [tool_id_to_index[corpus_id]]


class FiqaQueryArmMapper(QueryArmMapper):
    """Query-arm mapper for the FiQA dataset."""

    def extract_correct_arms(
        self, query: Dict[str, Any], tool_id_to_index: Dict[str, int]
    ) -> List[int]:
        """Resolve correct arms from the ``relevant_corpus_ids`` field.

        The field may hold either a list of IDs or a single ID.

        Args:
            query: Query row; only ``relevant_corpus_ids`` is read.
            tool_id_to_index: Lookup table from tool ID to arm index.

        Returns:
            Arm indices for every known ID, in the order they appear in the
            field; an empty list when nothing matches.
        """
        relevant = query.get("relevant_corpus_ids")

        if isinstance(relevant, list):
            # List form: keep only the IDs that exist in the tool set.
            return [
                tool_id_to_index[corpus_id]
                for corpus_id in relevant
                if corpus_id in tool_id_to_index
            ]

        # Scalar form: behaves like the single-ID mappers.
        if relevant and relevant in tool_id_to_index:
            return [tool_id_to_index[relevant]]
        return []


class MultihopQueryArmMapper(QueryArmMapper):
    """Query-arm mapper for the MultihopRAG dataset."""

    def extract_correct_arms(
        self, query: Dict[str, Any], tool_id_to_index: Dict[str, int]
    ) -> List[int]:
        """Resolve correct arms from the query's ``evidence_list`` field.

        Each evidence entry is expected to be a dict whose ``url`` value is
        looked up in ``tool_id_to_index``.

        Args:
            query: Query row; only the ``evidence_list`` key is read.
            tool_id_to_index: Lookup table from tool ID (here: URL) to arm
                index.

        Returns:
            Arm indices for every evidence URL that is a known tool ID, in
            evidence order. Empty when ``evidence_list`` is missing, ``None``
            or empty, so such queries are filtered out rather than crashing
            the load.
        """
        # `or ()` covers both a missing key and an explicit None value.
        evidence_items = query.get("evidence_list") or ()

        arms = []
        for item in evidence_items:
            # Skip malformed entries (non-dicts or entries without a URL)
            # instead of raising for the whole query.
            if not isinstance(item, dict):
                continue
            url = item.get("url")
            if url is not None and url in tool_id_to_index:
                arms.append(tool_id_to_index[url])
        return arms


class UniversalBanditDataLoader:
    """
    Universal data loader for contextual bandit experiments.

    Handles loading query and tool datasets with standardized embedding columns
    and flexible query-arm mapping logic for different dataset formats.
    """

    def __init__(
        self,
        tools_dataset_path: str,
        queries_dataset_path: str,
        query_arm_mapper: QueryArmMapper,
        device: torch.device,
        tool_text_field: str,
        query_text_field: str,
        max_queries: Optional[int] = None,
        subset: Optional[str] = None,
        cache_dir: Optional[str] = None,
        require_sort: bool = False,
    ):
        """
        Initialize the universal data loader.

        Args:
            tools_dataset_path: Path to embedded tools dataset
            queries_dataset_path: Path to embedded queries dataset
            query_arm_mapper: Strategy for extracting correct arms from queries
            device: PyTorch device for tensor allocation
            max_queries: Maximum number of queries to load (None for all)
            subset: Dataset subset to use (None for no subset filtering)
            cache_dir: Directory for caching processed data (None for auto)
        """
        self.tools_dataset_path = tools_dataset_path
        self.queries_dataset_path = queries_dataset_path
        self.query_arm_mapper = query_arm_mapper
        self.device = device
        self.max_queries = max_queries
        self.subset = subset
        self.cache_dir = cache_dir or os.path.join(
            os.path.dirname(queries_dataset_path), "universal_bandit_cache"
        )
        self.tool_text_field = tool_text_field
        self.query_text_field = query_text_field
        self.require_sort = require_sort

    def _get_cache_path(self, embedding_model: str) -> str:
        """Compute cache file path based on dataset paths, model, and configuration."""
        os.makedirs(self.cache_dir, exist_ok=True)

        # Get modification times for cache invalidation
        try:
            tools_mtime = os.path.getmtime(self.tools_dataset_path)
        except Exception:
            tools_mtime = 0.0
        try:
            queries_mtime = os.path.getmtime(self.queries_dataset_path)
        except Exception:
            queries_mtime = 0.0

        # Create cache key from all relevant parameters
        key_components = [
            str(self.tools_dataset_path),
            str(tools_mtime),
            str(self.queries_dataset_path),
            str(queries_mtime),
            embedding_model,
            str(self.max_queries),
            str(self.subset),
            self.query_arm_mapper.__class__.__name__,
        ]
        key_str = "|".join(key_components)
        cache_key = hashlib.md5(key_str.encode("utf-8")).hexdigest()
        return os.path.join(self.cache_dir, f"cache_{cache_key}.pt")

    def _get_embedding_column_name(self, embedding_model: str) -> str:
        """Get standardized embedding column name."""
        return f"embedding_{embedding_model}"

    def _load_and_filter_tools(
        self, tools_dataset: Any, embedding_model: str
    ) -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
        """Load and filter tools by embedding model availability."""
        embedding_col = self._get_embedding_column_name(embedding_model)

        filtered_tools = []
        tool_id_to_index = {}

        for tool in tqdm(tools_dataset, desc="Processing tools"):
            # Check if tool has the required embedding
            # Handle both new format (embedding_large) and old format (embed + embed_model)
            has_embedding = False

            if embedding_col in tool and tool[embedding_col] is not None:
                # New format: direct embedding column
                has_embedding = True
            elif "embed" in tool and "embed_model" in tool:
                # Old format: embed column with embed_model field
                if (
                    tool["embed_model"] == embedding_model
                    and tool["embed"] is not None
                ):
                    has_embedding = True

            if has_embedding:
                tool_index = len(filtered_tools)
                tool_id_to_index[tool["id"]] = tool_index
                filtered_tools.append(tool)

        return filtered_tools, tool_id_to_index

    def _load_and_filter_queries(
        self,
        queries_dataset: Any,
        embedding_model: str,
        tool_id_to_index: Dict[str, int],
    ) -> Tuple[List[Dict[str, Any]], List[List[int]]]:
        """Load and filter queries by embedding availability and correct arm mapping."""
        embedding_col = self._get_embedding_column_name(embedding_model)

        valid_queries = []
        query_correct_arms = []

        for query in tqdm(queries_dataset, desc="Processing queries"):
            # Check if query has the required embedding
            # Handle both new format (embedding_large) and old format (embed + embed_model)
            has_embedding = False

            if embedding_col in query and query[embedding_col] is not None:
                # New format: direct embedding column
                has_embedding = True
            elif "embed" in query and "embed_model" in query:
                # Old format: embed column with embed_model field
                if (
                    query["embed_model"] == embedding_model
                    and query["embed"] is not None
                ):
                    has_embedding = True

            if not has_embedding:
                continue

            # Extract correct arms using the mapper strategy
            correct_arms = self.query_arm_mapper.extract_correct_arms(
                query, tool_id_to_index
            )

            # Only include queries that have at least one valid correct arm
            if correct_arms:
                valid_queries.append(query)
                query_correct_arms.append(correct_arms)

        if not valid_queries:
            raise ValueError(
                f"No valid queries found after filtering. "
                f"This might happen when the dataset is truncated too aggressively, "
                f"when queries don't have required embeddings, or when correct arms "
                f"don't exist in the tool set."
            )

        return valid_queries, query_correct_arms

    def _create_embeddings_tensors(
        self,
        filtered_tools: List[Dict[str, Any]],
        embedding_model: str,
        true_embedding_model: str = "large",
        tools_dataset: Any = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Create initial and true embedding tensors from filtered tools."""
        embedding_col = self._get_embedding_column_name(embedding_model)
        true_embedding_col = self._get_embedding_column_name(
            true_embedding_model
        )

        # Initial embeddings (from chosen model)
        initial_embeddings_list = []
        for tool in filtered_tools:
            if embedding_col in tool and tool[embedding_col] is not None:
                # New format: direct embedding column
                initial_embeddings_list.append(tool[embedding_col])
            elif "embed" in tool and "embed_model" in tool:
                # Old format: embed column with embed_model field
                if tool["embed_model"] == embedding_model:
                    initial_embeddings_list.append(tool["embed"])

        initial_embeddings = torch.stack(
            [
                torch.tensor(emb, dtype=torch.float32)
                for emb in initial_embeddings_list
            ]
        ).to(self.device)

        # True embeddings (from true model, fallback to initial if not available)
        # For old format datasets like ToolRet, we need to find the true embedding model version
        true_embeddings_list = []

        if embedding_model == true_embedding_model:
            # If initial and true models are the same, just use initial embeddings
            true_embeddings_list = initial_embeddings_list.copy()
        else:
            # Need to find true embedding model versions
            # Create a mapping from tool ID to true embedding
            tool_id_to_true_embedding = {}

            # First pass: collect all tools with true embedding model
            for tool in tools_dataset:
                if "embed" in tool and "embed_model" in tool:
                    if tool["embed_model"] == true_embedding_model:
                        tool_id_to_true_embedding[tool["id"]] = tool["embed"]

            # Second pass: build true embeddings list aligned with filtered_tools
            for tool in filtered_tools:
                tool_id = tool["id"]
                if tool_id in tool_id_to_true_embedding:
                    true_embeddings_list.append(
                        tool_id_to_true_embedding[tool_id]
                    )
                else:
                    # Fallback to initial embedding
                    if "embed" in tool:
                        true_embeddings_list.append(tool["embed"])
                    elif embedding_col in tool:
                        true_embeddings_list.append(tool[embedding_col])

        true_embeddings = torch.stack(
            [
                torch.tensor(emb, dtype=torch.float32)
                for emb in true_embeddings_list
            ]
        ).to(self.device)

        return initial_embeddings, true_embeddings

    def _create_query_embeddings_tensor(
        self, valid_queries: List[Dict[str, Any]], embedding_model: str
    ) -> torch.Tensor:
        """Create query embeddings tensor from valid queries."""
        embedding_col = self._get_embedding_column_name(embedding_model)

        query_embeddings_list = []
        for query in valid_queries:
            if embedding_col in query and query[embedding_col] is not None:
                # New format: direct embedding column
                query_embeddings_list.append(query[embedding_col])
            elif "embed" in query and "embed_model" in query:
                # Old format: embed column with embed_model field
                if query["embed_model"] == embedding_model:
                    query_embeddings_list.append(query["embed"])

        if not query_embeddings_list:
            raise ValueError(
                f"No valid query embeddings found for embedding model '{embedding_model}'. "
                f"This might happen when the dataset is truncated too aggressively or "
                f"when queries don't have the required embeddings."
            )

        query_embeddings = torch.stack(
            [
                torch.tensor(emb, dtype=torch.float32)
                for emb in query_embeddings_list
            ]
        ).to(self.device)

        return query_embeddings

    def load_data(
        self,
        embedding_model: str = "large",
        true_embedding_model: str = "large",
        add_noise: bool = False,
        noise_std: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[int]]]:
        """
        Load data for bandit training with strict alignment between
        tools, queries, and their embeddings.

        Args:
            embedding_model: Which embedding model to use for initial embeddings
            true_embedding_model: Which embedding model to use for true embeddings
            add_noise: Whether to add noise to initial embeddings
            noise_std: Standard deviation of noise to add

        Returns:
            Tuple containing:
            - query_embeddings: Query embedding tensor [num_queries, embed_dim]
            - initial_embeddings: Initial tool embeddings [num_tools, embed_dim]
            - true_embeddings: True tool embeddings [num_tools, embed_dim]
            - query_correct_arms: List of correct arm indices for each query
        """
        print(f"Loading bandit data with {embedding_model} embeddings...")

        # Check cache first
        cache_path = self._get_cache_path(embedding_model)
        if os.path.exists(cache_path):
            print(f"Loading cached data from {cache_path}")
            try:
                cache_data = torch.load(cache_path, map_location="cpu")
                query_embeddings = cache_data["query_embeddings"].to(
                    self.device
                )
                initial_embeddings = cache_data["initial_embeddings"].to(
                    self.device
                )
                true_embeddings = cache_data["true_embeddings"].to(self.device)
                query_correct_arms = cache_data["query_correct_arms"]

                self.query_texts = cache_data["query_texts"]
                self.tool_texts = cache_data["tool_texts"]

                print(f"Loaded from cache:")
                print(f"  Query embeddings: {query_embeddings.shape}")
                print(f"  Initial embeddings: {initial_embeddings.shape}")
                print(f"  True embeddings: {true_embeddings.shape}")
                print(f"  Correct arms: {len(query_correct_arms)} lists")
                print(f"  Query texts: {len(self.query_texts)}")
                print(f"  Tool texts: {len(self.tool_texts)}")

                return (
                    query_embeddings,
                    initial_embeddings,
                    true_embeddings,
                    query_correct_arms,
                )
            except Exception as e:
                print(f"Failed to load cache ({e}), recomputing...")

        # Load datasets (prefer load_from_disk for local saved datasets)
        def _load_any(path: str):
            # If it's a local directory containing a saved HF dataset, use load_from_disk
            if os.path.isdir(path):
                # Heuristic: presence of dataset_dict.json implies a saved DatasetDict
                dd_json = os.path.join(path, "dataset_dict.json")
                if os.path.exists(dd_json):
                    return load_from_disk(path)
                # Or any subdirectory with dataset_info.json
                try:
                    for name in os.listdir(path):
                        subdir = os.path.join(path, name)
                        if os.path.isdir(subdir) and os.path.exists(
                            os.path.join(subdir, "dataset_info.json")
                        ):
                            return load_from_disk(path)
                except Exception:
                    pass
                # Fallback to load_from_disk for directories
                return load_from_disk(path)
            # Otherwise, try load_dataset (e.g., HF hub id or script)
            return load_dataset(path)

        tools_dataset = _load_any(self.tools_dataset_path)
        queries_dataset = _load_any(self.queries_dataset_path)

        # Normalize to a single split if a dict with 'train' is returned
        if isinstance(tools_dataset, dict) and "train" in tools_dataset:
            tools_dataset = tools_dataset["train"]
        if isinstance(queries_dataset, dict) and "train" in queries_dataset:
            queries_dataset = queries_dataset["train"]

        if self.require_sort:
            queries_dataset = queries_dataset.sort(self.query_text_field)
            tools_dataset = tools_dataset.sort(self.tool_text_field)

        # Apply subset filtering if specified
        if self.subset is not None:
            if self.subset in tools_dataset:
                tools_dataset = tools_dataset[self.subset]
                queries_dataset = queries_dataset[self.subset]
                print(f"Using '{self.subset}' subset")
            else:
                print(
                    f"Warning: Subset '{self.subset}' not found, using full dataset"
                )

        # ! handle nfcorpus tool id and mhr id
        if not "id" in tools_dataset[0]:
            if "_id" in tools_dataset[0]:
                tools_dataset = tools_dataset.rename_column("_id", "id")
            elif "url" in tools_dataset[0]:
                tools_dataset = tools_dataset.rename_column("url", "id")

        print(
            f"Processing {len(tools_dataset)} tools and {len(queries_dataset)} queries..."
        )

        # Step 1: Filter tools and build ID mapping
        print("Step 1: Filtering tools and building ID mapping...")
        filtered_tools, tool_id_to_index = self._load_and_filter_tools(
            tools_dataset, embedding_model
        )

        if len(filtered_tools) == 0:
            raise ValueError(
                f"No tools found with embedding model '{embedding_model}'"
            )

        print(
            f"Filtered to {len(filtered_tools)} tools with {embedding_model} embeddings"
        )

        # Step 2: Filter queries and build correct arms mapping
        print("Step 2: Processing queries and building correct arms...")
        valid_queries, query_correct_arms = self._load_and_filter_queries(
            queries_dataset, embedding_model, tool_id_to_index
        )

        if len(valid_queries) == 0:
            raise ValueError(
                "No valid queries found with correct tool mappings"
            )

        print(
            f"Found {len(valid_queries)} valid queries with correct tool mappings"
        )

        # Apply max_queries limit if specified
        if (
            self.max_queries is not None
            and len(valid_queries) > self.max_queries
        ):
            valid_queries = valid_queries[: self.max_queries]
            query_correct_arms = query_correct_arms[: self.max_queries]
            print(f"Limited to {self.max_queries} queries")

        # Step 3: Create tensors
        print("Step 3: Creating tensors...")

        # Create embedding tensors
        initial_embeddings, true_embeddings = self._create_embeddings_tensors(
            filtered_tools, embedding_model, true_embedding_model, tools_dataset
        )

        # Add noise to initial embeddings if requested
        if add_noise:
            noise = torch.randn_like(initial_embeddings) * noise_std
            initial_embeddings = initial_embeddings + noise
            print(
                f"Added Gaussian noise (std={noise_std}) to initial embeddings"
            )

        # Create query embeddings
        query_embeddings = self._create_query_embeddings_tensor(
            valid_queries, embedding_model
        )

        print(f"Created tensors:")
        print(f"  Query embeddings: {query_embeddings.shape}")
        print(f"  Initial embeddings: {initial_embeddings.shape}")
        print(f"  True embeddings: {true_embeddings.shape}")
        print(f"  Correct arms: {len(query_correct_arms)} lists")

        self.tool_texts = [
            tool[self.tool_text_field] for tool in filtered_tools
        ]
        self.query_texts = [
            query[self.query_text_field] for query in valid_queries
        ]

        print(f"  Stored internally: {len(self.query_texts)} query texts.")
        print(f"  Stored internally: {len(self.tool_texts)} tool texts.")

        # Cache the results
        try:
            cache_data = {
                "query_embeddings": query_embeddings.cpu(),
                "initial_embeddings": initial_embeddings.cpu(),
                "true_embeddings": true_embeddings.cpu(),
                "query_correct_arms": query_correct_arms,
                "query_texts": self.query_texts,
                "tool_texts": self.tool_texts,
            }
            torch.save(cache_data, cache_path)
            print(f"Cached prepared data to {cache_path}")
        except Exception as e:
            print(f"Warning: Failed to save cache ({e})")

        return (
            query_embeddings,
            initial_embeddings,
            true_embeddings,
            query_correct_arms,
        )

    def get_texts_by_indices(
        self,
        *,
        query_indices: Optional[List[int]] = None,
        arm_indices: Optional[List[int]] = None,
    ) -> Dict[str, List[str]]:
        """
        Retrieves the raw text for given query or arm indices.

        Args:
            query_indices: A list of query indices.
            arm_indices: A list of arm indices.

        Returns:
            A dictionary with 'queries' and/or 'arms' as keys and lists
            of corresponding texts as values.
        """
        if self.query_texts is None or self.tool_texts is None:
            raise RuntimeError("Data not loaded. Call `load_data` first.")

        results = {}
        if query_indices is not None:
            results["queries"] = [self.query_texts[i] for i in query_indices]
        if arm_indices is not None:
            results["arms"] = [self.tool_texts[i] for i in arm_indices]

        return results


# Convenience functions for specific datasets
def create_toolret_data_loader(
    tools_dataset_path: str = "embeddings/toolret_tools_embedded",
    queries_dataset_path: str = "embeddings/toolret_queries_embedded",
    device: torch.device = None,
    max_queries: Optional[int] = None,
    subset: str = "code",
    tool_text_field: str = "documentation",
    query_text_field: str = "query",
) -> UniversalBanditDataLoader:
    """Build a `UniversalBanditDataLoader` preconfigured for ToolRet.

    Correct arms are resolved with `ToolRetQueryArmMapper`. When `device`
    is None, CUDA is selected if `torch.cuda.is_available()` reports it,
    otherwise the CPU.
    """
    resolved_device = device
    if resolved_device is None:
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        resolved_device = torch.device(backend)

    loader = UniversalBanditDataLoader(
        tools_dataset_path=tools_dataset_path,
        queries_dataset_path=queries_dataset_path,
        query_arm_mapper=ToolRetQueryArmMapper(),
        device=resolved_device,
        tool_text_field=tool_text_field,
        query_text_field=query_text_field,
        max_queries=max_queries,
        subset=subset,
    )
    return loader


def create_ultratool_data_loader(
    tools_dataset_path: str = "embeddings/ultratool_tools_embedded",
    queries_dataset_path: str = "embeddings/ultratool_queries_embedded",
    device: torch.device = None,
    max_queries: Optional[int] = None,
    tool_text_field: str = "text_representation",
    query_text_field: str = "prompt",
) -> UniversalBanditDataLoader:
    """Build a `UniversalBanditDataLoader` preconfigured for UltraTool.

    Correct arms are resolved with `UltraToolQueryArmMapper` and no subset
    is selected. When `device` is None, CUDA is selected if
    `torch.cuda.is_available()` reports it, otherwise the CPU.
    """
    resolved_device = device
    if resolved_device is None:
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        resolved_device = torch.device(backend)

    loader = UniversalBanditDataLoader(
        tools_dataset_path=tools_dataset_path,
        queries_dataset_path=queries_dataset_path,
        query_arm_mapper=UltraToolQueryArmMapper(),
        device=resolved_device,
        tool_text_field=tool_text_field,
        query_text_field=query_text_field,
        max_queries=max_queries,
        subset=None,  # UltraTool is loaded without subset selection
    )
    return loader


def create_nfcorpus_data_loader(
    tools_dataset_path: str = "embeddings/nfcorpus_tools_embedded",
    queries_dataset_path: str = "embeddings/nfcorpus_queries_embedded",
    device: torch.device = None,
    max_queries: Optional[int] = None,
    tool_text_field: str = "text",
    query_text_field: str = "query",
) -> UniversalBanditDataLoader:
    """Build a `UniversalBanditDataLoader` preconfigured for NfCorpus.

    Correct arms are resolved with `NfCorpusQueryArmMapper` and no subset
    is selected. When `device` is None, CUDA is selected if
    `torch.cuda.is_available()` reports it, otherwise the CPU.
    """
    resolved_device = device
    if resolved_device is None:
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        resolved_device = torch.device(backend)

    loader = UniversalBanditDataLoader(
        tools_dataset_path=tools_dataset_path,
        queries_dataset_path=queries_dataset_path,
        query_arm_mapper=NfCorpusQueryArmMapper(),
        device=resolved_device,
        tool_text_field=tool_text_field,
        query_text_field=query_text_field,
        max_queries=max_queries,
        subset=None,  # NfCorpus is loaded without subset selection
    )
    return loader


def create_arguana_data_loader(
    tools_dataset_path: str = "embeddings/arguana_tools_embedded",
    queries_dataset_path: str = "embeddings/arguana_queries_embedded",
    device: torch.device = None,
    max_queries: Optional[int] = None,
    tool_text_field: str = "text",
    query_text_field: str = "text",
) -> UniversalBanditDataLoader:
    """Build a `UniversalBanditDataLoader` preconfigured for Arguana.

    Correct arms are resolved with `ArguanaQueryArmMapper` and no subset
    is selected. When `device` is None, CUDA is selected if
    `torch.cuda.is_available()` reports it, otherwise the CPU.
    """
    resolved_device = device
    if resolved_device is None:
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        resolved_device = torch.device(backend)

    loader = UniversalBanditDataLoader(
        tools_dataset_path=tools_dataset_path,
        queries_dataset_path=queries_dataset_path,
        query_arm_mapper=ArguanaQueryArmMapper(),
        device=resolved_device,
        tool_text_field=tool_text_field,
        query_text_field=query_text_field,
        max_queries=max_queries,
        subset=None,  # Arguana is loaded without subset selection
    )
    return loader


def create_fiqa_data_loader(
    tools_dataset_path: str = "embeddings/fiqa_tools_embedded",
    queries_dataset_path: str = "embeddings/fiqa_queries_embedded",
    device: torch.device = None,
    max_queries: Optional[int] = None,
    tool_text_field: str = "text",
    query_text_field: str = "text",
) -> UniversalBanditDataLoader:
    """Build a `UniversalBanditDataLoader` preconfigured for FiQA.

    Correct arms are resolved with `FiqaQueryArmMapper` and no subset is
    selected. When `device` is None, CUDA is selected if
    `torch.cuda.is_available()` reports it, otherwise the CPU.
    """
    resolved_device = device
    if resolved_device is None:
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        resolved_device = torch.device(backend)

    loader = UniversalBanditDataLoader(
        tools_dataset_path=tools_dataset_path,
        queries_dataset_path=queries_dataset_path,
        query_arm_mapper=FiqaQueryArmMapper(),
        device=resolved_device,
        tool_text_field=tool_text_field,
        query_text_field=query_text_field,
        max_queries=max_queries,
        subset=None,  # FiQA is loaded without subset selection
    )
    return loader


def create_multihop_data_loader(
    tools_dataset_path: str = "embeddings/mhr_tools_embedded",
    queries_dataset_path: str = "embeddings/mhr_decomposed_embedded",
    device: torch.device = None,
    max_queries: Optional[int] = None,
    tool_text_field: str = "body",
    query_text_field: str = "formatted_query",
) -> UniversalBanditDataLoader:
    """Build a `UniversalBanditDataLoader` preconfigured for MultihopRAG.

    Correct arms are resolved with `MultihopQueryArmMapper`, no subset is
    selected, and sorting is enabled (`require_sort=True`). When `device`
    is None, CUDA is selected if `torch.cuda.is_available()` reports it,
    otherwise the CPU.
    """
    resolved_device = device
    if resolved_device is None:
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        resolved_device = torch.device(backend)

    loader = UniversalBanditDataLoader(
        tools_dataset_path=tools_dataset_path,
        queries_dataset_path=queries_dataset_path,
        query_arm_mapper=MultihopQueryArmMapper(),
        device=resolved_device,
        tool_text_field=tool_text_field,
        query_text_field=query_text_field,
        max_queries=max_queries,
        subset=None,
        require_sort=True,
    )
    return loader
