"""
Alibaba Cluster Trace Data Loader

Data loader based on real-world Alibaba Cluster Trace 2018 characteristics.
Provides calibrated synthetic arrivals for ML serving scenarios (arrivals are
sampled from fitted distributions rather than read from trace files).

Regime Configuration (K=3):
- Regime 0: 8-bit Quantization (high quality, high memory usage)
- Regime 1: 4-bit Quantization (lower quality, 50% memory savings)
- Regime 2: Batching (efficiency gains from batched execution)

Resources (d=2):
- Resource 0: CPU
- Resource 1: Memory

The data is calibrated to match Alibaba 2018 trace statistics:
- Lognormal memory distribution (high variance)
- Exponential task durations
- Heterogeneous task mix
"""

import numpy as np
from typing import Dict, Any, Tuple, Optional
from scipy.stats import lognorm, expon, norm
from .base_loader import BaseDataLoader


# Experimental scenario profiles, keyed by scenario name.
#
# Each profile carries:
#   label              - short display name
#   description        - one-line human-readable summary
#   budget_scale       - two-element array; element 0 scales the base CPU
#                        budget and element 1 the base memory budget
#   regime_multipliers - per-regime factors (regime index -> reward/cpu/mem)
#                        applied to the regime-agnostic base arrivals
SCENARIO_PROFILES = {
    'quant_8bit': dict(
        label='Quant-8bit',
        description='Moderate budgets with emphasis on high-quality 8-bit inference',
        budget_scale=np.array([0.9, 0.9]),
        regime_multipliers={
            0: dict(reward=1.0, cpu=1.0, mem=1.0),      # 8-bit baseline
            1: dict(reward=0.72, cpu=1.0, mem=0.55),    # 4-bit aggressive
            2: dict(reward=1.1, cpu=0.85, mem=0.92),    # Batching
        },
    ),
    'quant_4bit': dict(
        label='Quant-4bit',
        description='Tight memory budget favouring aggressive quantization',
        budget_scale=np.array([0.8, 0.55]),
        regime_multipliers={
            0: dict(reward=1.0, cpu=1.05, mem=1.0),
            1: dict(reward=0.7, cpu=0.95, mem=0.45),
            2: dict(reward=1.05, cpu=0.8, mem=0.8),
        },
    ),
    'batching': dict(
        label='Batching',
        description='Looser budgets highlighting batching policies',
        budget_scale=np.array([1.1, 1.0]),
        regime_multipliers={
            0: dict(reward=1.05, cpu=0.95, mem=0.95),
            1: dict(reward=0.75, cpu=0.9, mem=0.5),
            2: dict(reward=1.25, cpu=0.7, mem=0.75),
        },
    ),
}


class AlibabaDataLoader(BaseDataLoader):
    """
    Data loader for Alibaba cluster trace experiments.

    Provides arrival sequences for three regimes representing
    different ML serving configurations:

    - Regime 0 (8-bit): Standard precision with high quality, high memory
    - Regime 1 (4-bit): Reduced precision with ~30% quality loss, 50% memory savings
    - Regime 2 (Batching): Batched execution with efficiency gains

    All arrivals are generated once, at construction time, from a seeded
    ``numpy.random.RandomState``; subsequent calls only read the stored arrays.

    Parameters
    ----------
    T : int
        Time horizon (default: 10000)
    seed : int
        Random seed for reproducibility (default: 42)
    scenario : str
        Scenario profile: 'quant_8bit', 'quant_4bit', or 'batching' (default: 'quant_8bit')
    base_cpu_budget : float
        Base CPU budget before scaling (default: 600.0)
    base_mem_budget : float
        Base memory budget before scaling (default: 600.0)

    Raises
    ------
    ValueError
        If ``scenario`` is not a key of ``SCENARIO_PROFILES``.

    Examples
    --------
    >>> loader = AlibabaDataLoader(T=10000, seed=42)
    >>> r, a = loader.get_arrival(theta=0, t=0)
    >>> B = loader.get_budget(rho=1.0)
    """

    def __init__(
        self,
        T: int = 10000,
        seed: int = 42,
        scenario: str = 'quant_8bit',
        base_cpu_budget: float = 600.0,
        base_mem_budget: float = 600.0
    ):
        # Fixed dimensions for Alibaba
        K = 3  # 8-bit, 4-bit, batching
        d = 2  # CPU, Memory

        # NOTE(review): the base class is assumed to set self.K, self.d,
        # self.T and an indexable self._arrivals container -- confirm in
        # base_loader.
        super().__init__(K, d, T, seed)

        # Validate scenario
        if scenario not in SCENARIO_PROFILES:
            raise ValueError(
                f"Unknown scenario '{scenario}'. "
                f"Supported: {list(SCENARIO_PROFILES.keys())}"
            )

        self.scenario = scenario
        self.scenario_profile = SCENARIO_PROFILES[scenario]
        self.base_cpu_budget = base_cpu_budget
        self.base_mem_budget = base_mem_budget

        # Single RNG shared by every sampling step so a seed fixes the data.
        self._rng = np.random.RandomState(seed)

        # Distributions (calibrated to Alibaba characteristics):
        #   memory   ~ lognormal with log-space sigma 1.8 and median exp(-1)
        #   duration ~ exponential with mean 300
        #   quality  ~ normal(1.0, 0.1)
        self._mem_dist = lognorm(s=1.8, scale=np.exp(-1))
        self._duration_dist = expon(scale=300)
        self._quality_dist = norm(loc=1.0, scale=0.1)

        # Order matters: base arrivals are sampled first, then the batching
        # regime consumes further draws from the same RNG.
        self._generate_base_arrivals()
        self._generate_arrivals()
        self._compute_nominal_budget()

    def _generate_base_arrivals(self):
        """Sample ``T`` regime-agnostic base arrivals into ``self._base_arrivals``.

        Each entry is a dict with keys ``'reward'`` (duration * quality),
        ``'cpu'`` (constant 0.5) and ``'mem'`` (lognormal sample). Every
        sampled quantity is floored at 0.01.
        """
        self._base_arrivals = []

        for _ in range(self.T):
            # The draw order (mem, duration, quality) per timestep defines the
            # seeded random stream; reordering would change the generated data.
            mem = max(0.01, self._mem_dist.rvs(random_state=self._rng))
            duration = max(0.01, self._duration_dist.rvs(random_state=self._rng))
            quality = max(0.01, self._quality_dist.rvs(random_state=self._rng))

            self._base_arrivals.append({
                'reward': duration * quality,
                'cpu': 0.5,  # Standard CPU requirement
                'mem': mem,
            })

    def _generate_arrivals(self):
        """Build the per-regime arrival arrays from the base arrivals.

        For every regime ``theta`` a ``(T, d + 1)`` array is stored in
        ``self._arrivals[theta]`` with columns ``[reward, cpu, mem]``. Each
        column is the base value times the scenario's regime multiplier,
        floored at 0.01. The batching regime (index 2) additionally scales the
        reward by a per-timestep efficiency factor derived from a Poisson(3)
        batch size.
        """
        regime_multipliers = self.scenario_profile['regime_multipliers']

        # Column views of the base arrivals so the scaling below is elementwise.
        base_reward = np.array([b['reward'] for b in self._base_arrivals], dtype=float)
        base_cpu = np.array([b['cpu'] for b in self._base_arrivals], dtype=float)
        base_mem = np.array([b['mem'] for b in self._base_arrivals], dtype=float)

        # Loop-invariant denominator of the batching efficiency factor.
        log_base = np.log(2 + 1e-9)

        for theta in range(self.K):
            mult = regime_multipliers[theta]
            reward = base_reward * mult['reward']

            if theta == 2:  # Batching regime has efficiency scaling
                # Draw batch sizes one at a time, in timestep order, so the
                # sequence of RNG calls matches a plain per-timestep loop.
                efficiency = np.array(
                    [
                        np.log1p(max(1, self._rng.poisson(3))) / log_base
                        for _ in range(self.T)
                    ],
                    dtype=float,
                )
                reward = reward * efficiency

            arrivals = np.zeros((self.T, self.d + 1))
            arrivals[:, 0] = np.maximum(0.01, reward)
            arrivals[:, 1] = np.maximum(0.01, base_cpu * mult['cpu'])
            arrivals[:, 2] = np.maximum(0.01, base_mem * mult['mem'])

            self._arrivals[theta] = arrivals

    def _compute_nominal_budget(self):
        """Set ``self._nominal_budget`` to the scaled base budgets.

        The result is ``[base_cpu_budget * scale[0], base_mem_budget * scale[1]]``
        where ``scale`` is the scenario's ``budget_scale``. The generated
        arrivals do not influence the budget.
        """
        scale = self.scenario_profile['budget_scale']
        self._nominal_budget = np.array([
            self.base_cpu_budget * scale[0],
            self.base_mem_budget * scale[1]
        ])

    def get_arrival(self, theta: int, t: int) -> Tuple[float, np.ndarray]:
        """Get arrival at timestep t under configuration theta.

        Returns
        -------
        (reward, consumption) : Tuple[float, np.ndarray]
            The reward as a float and a copy of the ``[cpu, mem]`` consumption
            vector, so callers may mutate it freely.

        Raises
        ------
        ValueError
            If ``theta`` is outside ``[0, K-1]`` or ``t`` is outside ``[0, T-1]``.
        """
        if theta < 0 or theta >= self.K:
            raise ValueError(f"theta {theta} out of range [0, {self.K-1}]")
        if t < 0 or t >= self.T:
            raise ValueError(f"t {t} out of range [0, {self.T-1}]")

        arrival = self._arrivals[theta][t]
        return float(arrival[0]), arrival[1:].copy()

    def get_budget(self, rho: float) -> np.ndarray:
        """Get the nominal ``[cpu, mem]`` budget scaled by rho (a new array)."""
        return rho * self._nominal_budget

    def get_regime_name(self, theta: int) -> str:
        """Get human-readable regime name (``"Regime <theta>"`` if unknown)."""
        names = {0: "8-bit Quantization", 1: "4-bit Quantization", 2: "Batching"}
        return names.get(theta, f"Regime {theta}")

    def get_resource_name(self, resource: int) -> str:
        """Get human-readable resource name (``"Resource <n>"`` if unknown)."""
        names = {0: "CPU", 1: "Memory"}
        return names.get(resource, f"Resource {resource}")

    def get_regime_statistics(
        self, theta: int, n_samples: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Compute statistics for a specific regime.

        Parameters
        ----------
        theta : int
            Regime index
        n_samples : int, optional
            Number of leading timesteps to include. ``None`` or ``0`` means
            all ``T`` timesteps; values above ``T`` are capped at ``T``.

        Returns
        -------
        stats : Dict[str, Any]
            Statistics including mean reward, mean consumption, efficiency
            (mean reward divided by the sum of mean CPU and mean memory).

        Raises
        ------
        ValueError
            If ``theta`` is outside ``[0, K-1]`` or ``n_samples`` is negative.
        """
        if theta < 0 or theta >= self.K:
            raise ValueError(f"theta {theta} out of range [0, {self.K-1}]")
        # A negative count would slice from the end and report a negative
        # sample size, so reject it explicitly.
        if n_samples is not None and n_samples < 0:
            raise ValueError(f"n_samples must be non-negative, got {n_samples}")

        n = min(n_samples, self.T) if n_samples else self.T

        arrivals = self._arrivals[theta][:n]
        rewards = arrivals[:, 0]
        cpu = arrivals[:, 1]
        mem = arrivals[:, 2]

        return {
            'regime_name': self.get_regime_name(theta),
            'n_samples': n,
            'mean_reward': float(np.mean(rewards)),
            'std_reward': float(np.std(rewards)),
            'mean_cpu': float(np.mean(cpu)),
            'std_cpu': float(np.std(cpu)),
            'mean_memory': float(np.mean(mem)),
            'std_memory': float(np.std(mem)),
            'efficiency': float(np.mean(rewards) / (np.mean(cpu) + np.mean(mem))),
        }

    def get_metadata(self) -> Dict[str, Any]:
        """Get metadata about the Alibaba dataset.

        Extends the base-class metadata with the scenario, regime and resource
        names, and the unscaled base budgets.
        """
        base = super().get_metadata()

        return {
            **base,
            'family': 'Alibaba',
            'name': 'Alibaba Cluster Trace 2018 (Calibrated)',
            'scenario': self.scenario,
            'scenario_label': self.scenario_profile['label'],
            'scenario_description': self.scenario_profile['description'],
            'regimes': [self.get_regime_name(k) for k in range(self.K)],
            'resources': [self.get_resource_name(r) for r in range(self.d)],
            'base_budgets': {
                'cpu': self.base_cpu_budget,
                'memory': self.base_mem_budget
            },
            'deterministic': False,
        }


def test_alibaba_loader():
    """Smoke-test the loader: print a summary for every scenario profile.

    For each scenario a small loader (T=1000, seed=42) is built and its
    metadata, per-regime statistics, nominal budget and validation result are
    printed. Nothing is asserted; failures surface as exceptions.
    """
    banner = "=" * 60
    metadata_fields = ('family', 'scenario_label', 'K', 'd', 'T')

    print(banner)
    print("Testing Alibaba Data Loader")
    print(banner)

    for scenario_name in ('quant_8bit', 'quant_4bit', 'batching'):
        print(f"\n--- Scenario: {scenario_name} ---")

        data_loader = AlibabaDataLoader(T=1000, seed=42, scenario=scenario_name)

        print("\nMetadata:")
        metadata = data_loader.get_metadata()
        for field in metadata_fields:
            print(f"  {field}: {metadata.get(field)}")

        print("\nRegime Statistics:")
        for regime in range(data_loader.K):
            summary = data_loader.get_regime_statistics(regime)
            mean_reward = summary['mean_reward']
            std_reward = summary['std_reward']
            print(f"  {summary['regime_name']}:")
            print(f"    Mean reward: {mean_reward:.3f} +/- {std_reward:.3f}")
            print(
                f"    Mean CPU: {summary['mean_cpu']:.3f}, "
                f"Memory: {summary['mean_memory']:.3f}"
            )
            print(f"    Efficiency: {summary['efficiency']:.3f}")

        nominal_budget = data_loader.get_budget(1.0)
        print(f"\nBudget (rho=1.0): {nominal_budget}")
        print(f"Validation: {data_loader.validate()}")

    print("\n" + banner)


# Run the smoke test when this file is executed as the main program.
# NOTE: the module uses a relative import (`.base_loader`), so it has to be
# launched as part of its package (e.g. with `python -m`), not as a bare script.
if __name__ == "__main__":
    test_alibaba_loader()
