"""
Base Data Loader

Abstract base class for all data loaders in the SP-UCB-OLP framework.
"""

import numpy as np
from abc import ABC, abstractmethod
from typing import Dict, Any, Tuple, List


class BaseDataLoader(ABC):
    """
    Abstract base class for data loaders.

    All data loaders must implement:
    - get_arrival(theta, t): Get (reward, consumption) for config theta at time t
    - get_budget(rho): Get budget vector scaled by rho

    Subclasses are expected to populate ``_arrivals`` (one ``(T, d+1)``
    array per configuration, reward in column 0 and consumption in the
    remaining columns) and, if needed, override ``_nominal_budget``.
    Pre-generating the arrivals keeps runs reproducible.

    Parameters
    ----------
    K : int
        Number of configurations
    d : int
        Number of resource dimensions
    T : int
        Time horizon
    seed : int
        Random seed for reproducibility
    """

    def __init__(
        self,
        K: int,
        d: int,
        T: int,
        seed: int = 42
    ):
        self.K = K
        self.d = d
        self.T = T
        self.seed = seed

        # To be set by subclasses
        self._arrivals: Dict[int, np.ndarray] = {}  # theta -> (T, d+1) array
        # Placeholder budget of ones; subclasses may replace it.
        self._nominal_budget: np.ndarray = np.ones(d)

    @abstractmethod
    def get_arrival(self, theta: int, t: int) -> Tuple[float, np.ndarray]:
        """
        Get arrival at timestep t under configuration theta.

        Parameters
        ----------
        theta : int
            Configuration index in [0, K-1]
        t : int
            Timestep in [0, T-1]

        Returns
        -------
        r : float
            Reward value
        a : np.ndarray
            Resource consumption vector (d,)
        """
        pass

    @abstractmethod
    def get_budget(self, rho: float) -> np.ndarray:
        """
        Get total budget vector scaled by rho.

        Parameters
        ----------
        rho : float
            Budget scaling factor (0.5 = tight, 1.0 = nominal, 1.5 = loose)

        Returns
        -------
        B : np.ndarray
            Total budget vector (d,)
        """
        pass

    def get_per_period_budget(self, rho: float) -> np.ndarray:
        """
        Get per-period budget b = B / T.

        Parameters
        ----------
        rho : float
            Budget scaling factor, forwarded unchanged to ``get_budget``.

        Returns
        -------
        b : np.ndarray
            Total budget from ``get_budget(rho)`` divided by the horizon T.
        """
        return self.get_budget(rho) / self.T

    def get_all_samples(self, theta: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Get all samples for a configuration.

        Parameters
        ----------
        theta : int
            Configuration index; must be a key of ``_arrivals``
            (a missing key raises ``KeyError``).

        Returns
        -------
        rewards : np.ndarray
            All rewards (T,)
        consumptions : np.ndarray
            All consumptions (T, d)
        """
        arrivals = self._arrivals[theta]
        # Column 0 holds the reward, the remaining columns the consumption.
        return arrivals[:, 0], arrivals[:, 1:]

    def get_samples_dict(self) -> Dict[int, Tuple[np.ndarray, np.ndarray]]:
        """
        Get samples dictionary suitable for oracle computation.

        Returns
        -------
        samples : dict
            Maps every theta in ``range(K)`` to the ``(rewards, consumptions)``
            pair returned by ``get_all_samples(theta)``.
        """
        return {
            theta: self.get_all_samples(theta)
            for theta in range(self.K)
        }

    def get_metadata(self) -> Dict[str, Any]:
        """Get metadata about the dataset (K, d, T and seed)."""
        return {
            'K': self.K,
            'd': self.d,
            'T': self.T,
            'seed': self.seed,
        }

    def validate(self) -> bool:
        """
        Validate that data is properly generated.

        Spot-checks the first ``min(10, T)`` timesteps of every
        configuration: the reward must be a real scalar (Python or NumPy
        int/float), the consumption must have length ``d`` and be
        element-wise non-negative.

        Returns
        -------
        ok : bool
            True if every checked arrival passes; False if any check
            fails or if retrieving/inspecting an arrival raises.
        """
        # NumPy scalars such as np.float32 / np.int64 are not subclasses of
        # the Python builtins, so they must be listed explicitly.
        scalar_types = (int, float, np.integer, np.floating)
        checked_steps = min(10, self.T)

        # Explicit checks instead of `assert`: assertions are stripped when
        # Python runs with -O, which would make validation a no-op.
        try:
            for theta in range(self.K):
                for t in range(checked_steps):
                    r, a = self.get_arrival(theta, t)
                    if not isinstance(r, scalar_types):
                        return False
                    if len(a) != self.d:
                        return False
                    if not np.all(a >= 0):
                        return False
        except Exception:
            # Deliberately best-effort: this method reports validity as a
            # bool, so any failure while fetching or inspecting an arrival
            # counts as "invalid" rather than propagating.
            return False
        return True
