#!/usr/bin/env python3
"""
BP Online Oracle implementations for computing lower bounds and exact solutions.
"""

import numpy as np
import os
import time
from typing import Dict, List, Tuple, Optional


class BPOnlineOracle:
    """
    BP Online oracle for computing lower bounds and exact solutions.

    `solve_exact` returns a ``(value, status)`` tuple. ``value`` is a number of
    bins; ``status`` says how it was obtained (e.g. "LB", "OPTIMAL",
    "FEASIBLE", "FEASIBLE:TIMEOUT", "LB_FALLBACK", "LB_NO_ORTOOLS").

    Args:
        config: BPOnlineConfig or any config object with oracle_type,
            oracle_timeout/timeout_seconds and (optionally) debug_mode attributes
    """

    # Relative tolerance absorbing floating-point noise in item sizes
    # (e.g. 0.1 + 0.2 > 0.3), so bounds are not pushed past the true optimum.
    _FLOAT_TOL = 1e-9
    # Largest power of ten tried when converting fractional sizes to the
    # integers CP-SAT requires.
    _MAX_DECIMALS = 6

    def __init__(self, config):
        self.config = config
        # None: not probed (oracle_type != "cp-sat"); True/False: import result.
        self._ortools_available = None
        oracle_type = getattr(config, 'oracle_type', 'lb')
        if oracle_type == "cp-sat":
            try:
                from ortools.sat.python import cp_model  # noqa: F401
                self._ortools_available = True
            except ImportError:
                self._ortools_available = False

    def _debug(self, message: str) -> None:
        """Print ``message`` (flushed) when ``config.debug_mode`` is truthy."""
        if getattr(self.config, 'debug_mode', False):
            print(message, flush=True)

    def _timeout_seconds(self) -> float:
        """Return the solver time limit: oracle_timeout, else timeout_seconds, else 60.

        Attributes that are missing or set to None are skipped, so a config
        with ``oracle_timeout=None`` does not break numeric comparisons.
        """
        for attr in ('oracle_timeout', 'timeout_seconds'):
            value = getattr(self.config, attr, None)
            if value is not None:
                return float(value)
        return 60.0

    def martello_toth_bound(self, items: np.ndarray, capacity: float) -> float:
        """
        Compute a lower bound on the number of bins required.

        The bound is ``max(L1, L_large)`` where
          - L1 (volume bound) = ceil(sum(items) / capacity), and
          - L_large = number of items strictly larger than capacity / 2
            (no two such items can share a bin).
        These are the elementary bounds discussed by Martello & Toth (1990);
        their stronger L2 bound is not computed here.

        A small relative tolerance keeps floating-point noise from pushing the
        bound above the true optimum.

        Args:
            items: Array of item sizes
            capacity: Bin capacity

        Returns:
            Lower bound on number of bins required (as a float)
        """
        if len(items) == 0:
            return 0.0

        items_array = np.asarray(items, dtype=float)
        capacity = float(capacity)

        # L1 bound: volume lower bound (tolerance avoids ceil(1.0000000000000002) == 2)
        total_volume = np.sum(items_array)
        l1_bound = np.ceil(total_volume / capacity - self._FLOAT_TOL)

        # Large-item bound: items above half the capacity each need their own bin
        half_capacity = capacity / 2.0 + self._FLOAT_TOL * capacity
        l2_bound = int(np.count_nonzero(items_array > half_capacity))

        return float(max(l1_bound, l2_bound))

    def calculate_ffd_bins(self, items: np.ndarray, capacity: float) -> int:
        """
        Use the First Fit Decreasing (FFD) heuristic to get a bin-packing upper bound.

        Items are sorted by decreasing size; each goes into the first open bin
        with enough remaining room, else a new bin is opened. An item larger
        than the capacity still gets its own bin.

        Args:
            items: Array of item sizes
            capacity: Bin capacity

        Returns:
            Number of bins used by FFD (an upper bound on the optimum)
        """
        if len(items) == 0:
            return 0

        items_array = np.asarray(items, dtype=float)
        capacity_float = float(capacity)

        # Sort by size in descending order
        sorted_items = np.sort(items_array)[::-1]

        # Each entry is the used capacity of one open bin
        bins = []
        for item_size in sorted_items:
            for i, used in enumerate(bins):
                if used + item_size <= capacity_float:
                    bins[i] = used + item_size
                    break
            else:
                bins.append(item_size)

        return len(bins)

    def _integerize(self, items_array: np.ndarray,
                    capacity: float) -> Tuple[List[int], int, bool]:
        """
        Convert item sizes and capacity to integers for the CP-SAT model.

        Scale factors 10**0 .. 10**_MAX_DECIMALS are tried in order. The first
        one that makes every size and the capacity integral (up to float noise)
        is used, and the integer instance is exactly equivalent to the original.
        If none works, the largest scale is used with sizes rounded up and the
        capacity rounded down. Every packing of that integer model is still
        feasible for the original instance, but its optimum is only an upper bound.

        Returns:
            (int_items, int_cap, exact), where exact is False when the
            conservative rounding fallback was used.
        """
        tol = self._FLOAT_TOL
        for decimals in range(self._MAX_DECIMALS + 1):
            scale = 10 ** decimals
            scaled_items = items_array * scale
            scaled_cap = capacity * scale
            rounded_items = np.round(scaled_items)
            rounded_cap = np.round(scaled_cap)
            items_integral = np.all(
                np.abs(scaled_items - rounded_items)
                <= tol * np.maximum(1.0, np.abs(scaled_items))
            )
            cap_integral = abs(scaled_cap - rounded_cap) <= tol * max(1.0, abs(scaled_cap))
            if items_integral and cap_integral:
                return [int(v) for v in rounded_items], int(rounded_cap), True

        scale = 10 ** self._MAX_DECIMALS
        int_items = [int(v) for v in np.ceil(items_array * scale)]
        int_cap = int(np.floor(capacity * scale))
        return int_items, int_cap, False

    def solve_with_cp_sat(self, items: np.ndarray,
                          capacity: float) -> Optional[Tuple[float, str]]:
        """
        Solve Bin Packing using OR-Tools CP-SAT.

        Args:
            items: Array of item sizes
            capacity: Bin capacity

        Returns:
            None if OR-Tools is unavailable; otherwise ``(num_bins, status)``.
            ``status`` is one of:
              - "OPTIMAL"
              - "FEASIBLE", "FEASIBLE:TIMEOUT" (feasible but not proven optimal)
              - "LB_INFEASIBLE" (some item exceeds capacity)
              - "LB_SIZE_LIMIT" (too many items for the time budget)
              - "LB_FALLBACK" (no solution found, or the solver raised)
            For the LB_* statuses, ``num_bins`` is the lower bound.
        """
        if not self._ortools_available:
            return None

        try:
            from ortools.sat.python import cp_model
        except ImportError:
            return None

        n_items = len(items)
        if n_items == 0:
            return (0.0, "OPTIMAL")

        items_array = np.asarray(items, dtype=float)
        capacity = float(capacity)
        if np.any(items_array > capacity):
            lb = self.martello_toth_bound(items_array, capacity)
            return (float(lb), "LB_INFEASIBLE")

        timeout_seconds = self._timeout_seconds()
        max_items = 50 if timeout_seconds <= 30 else 100
        if n_items > max_items:
            lb = self.martello_toth_bound(items_array, capacity)
            self._debug(f"      [Oracle] CP-SAT skipped: n_items={n_items} > max_items={max_items}, using LB={lb:.0f}")
            return (float(lb), "LB_SIZE_LIMIT")

        # CP-SAT needs integer coefficients; plain int() would truncate fractional sizes.
        int_items, int_cap, exact = self._integerize(items_array, capacity)

        lb = int(self.martello_toth_bound(items_array, capacity))
        # FFD on the integer data guarantees `ub` bins are enough for the model itself.
        ub = max(self.calculate_ffd_bins(np.asarray(int_items, dtype=float), int_cap), lb)

        try:
            model = cp_model.CpModel()
            # x[i][j] == 1 iff item i goes into bin j; y[j] == 1 iff bin j is used.
            x = [
                [model.NewBoolVar(f"x_{i}_{j}") for j in range(ub)]
                for i in range(n_items)
            ]
            y = [model.NewBoolVar(f"y_{j}") for j in range(ub)]

            # Minimize the number of bins used
            model.Minimize(sum(y))
            for i in range(n_items):
                model.Add(sum(x[i][j] for j in range(ub)) == 1)

            for j in range(ub):
                model.Add(
                    sum(int_items[i] * x[i][j] for i in range(n_items))
                    <= int_cap * y[j]
                )
            # Symmetry breaking: used bins form a prefix.
            for j in range(ub - 1):
                model.Add(y[j + 1] <= y[j])

            model.Add(sum(y) >= lb)

            solver = cp_model.CpSolver()
            solver.parameters.max_time_in_seconds = float(timeout_seconds)
            solver.parameters.num_search_workers = max(1, min(4, os.cpu_count() or 1))
            solver.parameters.log_search_progress = False
            solver.parameters.search_branching = cp_model.PORTFOLIO_SEARCH
            solver.parameters.linearization_level = 2

            solve_start = time.perf_counter()
            status = solver.Solve(model)
            solve_time = time.perf_counter() - solve_start
        except Exception as exc:  # best-effort oracle: degrade to the lower bound
            self._debug(f"      [Oracle] CP-SAT raised {type(exc).__name__}: {exc}, using LB={lb}")
            return (float(lb), "LB_FALLBACK")

        if status == cp_model.OPTIMAL:
            val = int(round(solver.ObjectiveValue()))
            # With conservative rounding the model optimum is only an upper bound.
            return (float(val), "OPTIMAL" if exact else "FEASIBLE")
        if status == cp_model.FEASIBLE:
            val = int(round(solver.ObjectiveValue()))
            if solve_time >= timeout_seconds - 1.0:
                return (float(val), "FEASIBLE:TIMEOUT")
            return (float(val), "FEASIBLE")
        return (float(lb), "LB_FALLBACK")

    def solve_exact(self, instance: Dict) -> Tuple[float, str]:
        """
        Compute exact solution or lower bound for a BP Online instance.

        Args:
            instance: Dict with keys 'items' (np.ndarray), 'capacity' (int), 'num_items' (int)

        Returns:
            ``(value, status)``: the CP-SAT result if oracle_type == "cp-sat"
            and OR-Tools is available, otherwise the lower bound with an
            "LB*" status.
        """
        items = np.asarray(instance['items'], dtype=float)
        capacity = float(instance['capacity'])

        oracle_type = getattr(self.config, 'oracle_type', 'lb')
        if oracle_type != "cp-sat":
            return (self.martello_toth_bound(items, capacity), "LB")

        if not self._ortools_available:
            lb_value = self.martello_toth_bound(items, capacity)
            self._debug(f"      [Oracle] CP-SAT not available (OR-Tools not installed), using LB={lb_value:.0f}")
            return (lb_value, "LB_NO_ORTOOLS")

        result = self.solve_with_cp_sat(items, capacity)
        if result is None:
            lb_value = self.martello_toth_bound(items, capacity)
            self._debug(f"      [Oracle] CP-SAT returned None (likely size limit), using LB={lb_value:.0f}")
            return (lb_value, "LB_CP_SAT_NONE")
        if isinstance(result, tuple):
            return result
        # Defensive: accept a bare numeric result from older solver versions.
        return (result, "OPTIMAL")

    def compute_gap(self, num_bins: float, lb: float) -> float:
        """
        Compute gap percentage.

        Args:
            num_bins: Number of bins used by solver
            lb: Lower bound (optimal or theoretical)

        Returns:
            Gap percentage: (num_bins / lb - 1) * 100. Returns 0.0 when both
            are zero (empty instance) and inf when only lb is non-positive.
        """
        if lb <= 0:
            return 0.0 if num_bins <= 0 else float('inf')
        return (num_bins / lb - 1.0) * 100.0


# Factory function for easy oracle creation
def create_bp_online_oracle(config=None, oracle_type: str = None, **kwargs) -> BPOnlineOracle:
    """
    Create a BP Online oracle instance.

    Args:
        config: BPOnlineConfig or HeuPSROConfig object (preferred). If provided,
                oracle_type and other kwargs will override config values.
        oracle_type: Oracle type string (used if config is None or to override config)
        **kwargs: Additional config parameters (used if config is None or to override config).
                  An 'oracle_type' key here is honoured when the oracle_type argument is not given.

    Returns:
        BPOnlineOracle instance
    """
    # No overrides: use the caller's config object as-is.
    if config is not None and oracle_type is None and not kwargs:
        return BPOnlineOracle(config)

    from .config import BPOnlineConfig

    if config is None:
        # Pop so 'oracle_type' is not passed to BPOnlineConfig twice (TypeError).
        kwarg_oracle_type = kwargs.pop('oracle_type', None)
        oracle_type = oracle_type or kwarg_oracle_type or 'lb'
        return BPOnlineOracle(BPOnlineConfig(oracle_type=oracle_type, **kwargs))

    # Copy the existing config, then apply overrides (kwargs win over oracle_type).
    if hasattr(config, '__dict__'):
        config_dict = config.__dict__.copy()
    else:
        config_dict = {
            attr: getattr(config, attr)
            for attr in ('oracle_type', 'oracle_timeout', 'timeout_seconds')
            if hasattr(config, attr)
        }

    if oracle_type is not None:
        config_dict['oracle_type'] = oracle_type
    config_dict.update(kwargs)

    return BPOnlineOracle(BPOnlineConfig(**config_dict))

