import torch
from torch import Tensor
import numpy as np
import pandas as pd

from yahpo_gym import local_config
from yahpo_gym import benchmark_set

from botorch.test_functions.base import (
    MultiObjectiveTestProblem,
    ConstrainedBaseTestProblem
)

# Configure YAHPO Gym once at import time: initialise its local config and
# point it at the data directory used by the benchmark problems below.
# NOTE(review): the data path is relative; presumably it is resolved against
# the current working directory — confirm callers run from the repo root.
local_config.init_config()
local_config.set_data_path("rescue/problems/hpo/yahpo_data")
# NOTE(review): bare attribute access whose value is discarded; it looks like
# a leftover with no visible effect here — confirm before removing.
benchmark_set.cfg

class HPOXGBoost(MultiObjectiveTestProblem):
    r"""
    The XGBoost Hyperparameter Optimization problem using YAHPO Gym.

    Evaluations are answered by the YAHPO Gym ``iaml_xgboost`` benchmark on
    instance ``41146`` with the ``dart`` booster.

    Objectives (2-D):
        - mmce (to minimize)
        - rammodel (to minimize)
    """

    dim = 13
    num_objectives = 2  # mmce and rammodel
    design_var_names = [
        'alpha',
        'colsample_bylevel',
        'colsample_bytree',
        'eta',
        'gamma',
        'lambda',
        'max_depth',
        'min_child_weight',
        'nrounds',
        'rate_drop',
        'skip_drop',
        'subsample',
        'trainsize'
    ]
    objective_var_names = [
        'mmce',     # classification error
        'rammodel'  # model memory footprint
    ]
    fidelity_param_name = 'trainsize'

    # NOTE(review): max_depth (index 6) and nrounds (index 8) are truncated to
    # int before evaluation, yet every index is declared continuous here.
    continuous_inds = list(range(13))
    discrete_inds = []
    categorical_inds = []

    _bounds = [
        (0.0005, 1.0),    # alpha
        (0.02, 1.0),      # colsample_bylevel
        (0.01, 1.0),      # colsample_bytree
        (0.0005, 1.0),    # eta
        (0.0005, 1.0),    # gamma
        (0.0005, 1.0),    # lambda
        (1, 15),          # max_depth
        (2.72, 149.0),    # min_child_weight
        (3, 2000),        # nrounds
        (0.0, 1.0),       # rate_drop
        (0.0, 1.0),       # skip_drop
        (0.1, 1.0),       # subsample
        (0.03, 1.0),      # trainsize (fidelity parameter)
    ]
    _ref_point = [0.56406915, 24.8914]
    # Approximated using NSGA-II with
    #   population_size=250
    #   max_gen=100
    #   seeds=random.sample(range(1000), 100)
    # For seed 142
    _max_hv = 12.606743973304742

    # YAHPO benchmark shared by all instances of this class; created on first
    # use so it is constructed once instead of on every evaluation.
    _bench = None

    def __init__(
        self,
        noise_std: None | float | list[float] = None,
        negate: bool = False,
        dtype: torch.dtype = torch.double,
    ) -> None:
        r"""
        Args:
            noise_std: Standard deviation of the observation noise.
            negate: If True, negate the objectives.
            dtype: The dtype that is used for the bounds of the function.
        """
        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)

    def _evaluate_true(self, X: Tensor) -> Tensor:
        return self._XGBoost_HPO(X)

    @classmethod
    def _get_bench(cls):
        """Return the cached ``iaml_xgboost`` benchmark set to instance 41146."""
        if cls._bench is None:
            bench = benchmark_set.BenchmarkSet("iaml_xgboost")
            bench.set_instance("41146")
            cls._bench = bench
        return cls._bench

    @staticmethod
    def _XGBoost_HPO(X: Tensor) -> Tensor:
        """
        The XGBoost Hyperparameter Optimization problem using YAHPO Gym.

        Args:
            X: Tensor of shape (..., n_features) whose last dimension follows
                ``design_var_names``. Supports arbitrary batch dimensions.

        Returns:
            Tensor: shape (..., n_objectives) with objective values, in the
                dtype and on the device of ``X``:
                    - Objective 1: mmce (minimization)
                    - Objective 2: rammodel (minimization)
        """
        # Store original batch shape and flatten to (n_samples, n_features)
        original_shape = X.shape[:-1]
        X_flat = X.reshape(-1, X.shape[-1])
        if X_flat.shape[0] == 0:
            # Nothing to evaluate; avoid querying YAHPO with an empty batch.
            return X.new_empty((*original_shape, HPOXGBoost.num_objectives))

        # YAHPO consumes plain Python/numpy values, not tensors
        X_np = X_flat.detach().cpu().numpy().astype(np.float64)

        names = HPOXGBoost.design_var_names
        configs = []
        for row in X_np:
            config = {name: float(row[j]) for j, name in enumerate(names)}
            # Integer hyperparameters; int() truncates toward zero, so the
            # upper bound is only reached when X lies exactly on it.
            config['max_depth'] = int(config['max_depth'])
            config['nrounds'] = int(config['nrounds'])
            config['booster'] = 'dart'  # use dart to enable rate_drop and skip_drop
            config['task_id'] = '41146'   # fixed instance
            configs.append(config)

        results = pd.DataFrame(
            HPOXGBoost._get_bench().objective_function(configs)
        )
        objectives = np.column_stack(
            [results['mmce'].to_numpy(), results['rammodel'].to_numpy()]
        )

        # Back to a tensor with the original batch dimensions: (..., n_objectives)
        objectives_tensor = torch.tensor(
            objectives, dtype=X.dtype, device=X.device)
        return objectives_tensor.reshape(*original_shape, -1)
    

class HPOXGBoostConstrained(
    MultiObjectiveTestProblem,
    ConstrainedBaseTestProblem
):
    r"""
    The constrained XGBoost Hyperparameter Optimization problem using YAHPO Gym.

    Evaluations are answered by the YAHPO Gym ``iaml_xgboost`` benchmark on
    instance ``41146`` with the ``dart`` booster.

    Objectives (2-D):
        - f1 (to maximize; returned negated as -f1 so both are minimized)
        - timetrain (to minimize)

    Constraints:
        - ramtrain <= 15 (see ``_ramtrain_limit``)
    """

    dim = 13
    num_objectives = 2
    num_constraints = 1

    design_var_names = [
        'alpha',
        'colsample_bylevel',
        'colsample_bytree',
        'eta',
        'gamma',
        'lambda',
        'max_depth',
        'min_child_weight',
        'nrounds',
        'rate_drop',
        'skip_drop',
        'subsample',
        'trainsize'
    ]
    objective_var_names = [
        'f1',
        'timetrain'
    ]
    constraint_var_names = ['ramtrain']
    fidelity_param_name = 'trainsize'

    # NOTE(review): max_depth (index 6) and nrounds (index 8) are truncated to
    # int before evaluation, yet every index is declared continuous here.
    continuous_inds = list(range(13))
    discrete_inds = []
    categorical_inds = []

    _bounds = [
        (0.0005, 1.0),    # alpha
        (0.02, 1.0),      # colsample_bylevel
        (0.01, 1.0),      # colsample_bytree
        (0.0005, 1.0),    # eta
        (0.0005, 1.0),    # gamma
        (0.0005, 1.0),    # lambda
        (1, 15),          # max_depth
        (2.72, 149.0),    # min_child_weight
        (3, 2000),        # nrounds
        (0.0, 1.0),       # rate_drop
        (0.0, 1.0),       # skip_drop
        (0.1, 1.0),       # subsample
        (0.03, 1.0),      # trainsize (fidelity parameter)
    ]
    _ref_point = [-4.6e-7, 479.0]
    _max_hv = 449.87090643479456  # Approximated using NSGA-II, seed:506

    # Upper limit on ramtrain; slack = _ramtrain_limit - ramtrain.
    # NOTE(review): confirm 15.0 is the intended limit — the reference point
    # and max HV above were presumably computed with it.
    _ramtrain_limit = 15.0

    # YAHPO benchmark shared by all instances of this class; created on first
    # use so it is constructed once instead of on every evaluation.
    _bench = None

    def __init__(
        self,
        noise_std: None | float | list[float] = None,
        negate: bool = False,
        dtype: torch.dtype = torch.double,
    ) -> None:
        r"""
        Args:
            noise_std: Standard deviation of the observation noise.
            negate: If True, negate the objectives.
            dtype: The dtype that is used for the bounds of the function.
        """
        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
        # Flat ramtrain values from the most recent benchmark run, plus a copy
        # of the input that produced them (used to detect stale caches).
        self._cache_ramtrain = None
        self._cache_X = None

    def _evaluate_true(self, X: Tensor) -> Tensor:
        return self._XGBoost_HPO(X)

    def _cache_matches(self, X: Tensor) -> bool:
        """Return True if the cached ramtrain values were computed for ``X``."""
        cached = self._cache_X
        return (
            self._cache_ramtrain is not None
            and cached is not None
            and cached.shape == X.shape
            and cached.dtype == X.dtype
            and cached.device == X.device
            and torch.equal(cached, X)
        )

    def _evaluate_slack_true(self, X: Tensor) -> Tensor:
        """
        Evaluate the constraint slack for input tensor X.

        Reuses the ramtrain values cached by ``_XGBoost_HPO`` only when they
        were produced for an input equal to ``X``; otherwise the benchmark is
        run on ``X`` first, so the slack always corresponds to ``X``.

        Args:
            X: A tensor of shape (..., n_features) with hyperparameter configurations.

        Returns:
            A tensor of shape (..., 1) containing ``_ramtrain_limit - ramtrain``;
            negative values indicate a constraint violation.
        """
        if not self._cache_matches(X):
            self._XGBoost_HPO(X)

        ramtrain_tensor = torch.tensor(
            self._cache_ramtrain,
            dtype=X.dtype,
            device=X.device
        )
        slack = self._ramtrain_limit - ramtrain_tensor
        # Restore the batch dimensions of X: (..., 1)
        return slack.reshape(*X.shape[:-1], 1)

    @classmethod
    def _get_bench(cls):
        """Return the cached ``iaml_xgboost`` benchmark set to instance 41146."""
        if cls._bench is None:
            bench = benchmark_set.BenchmarkSet("iaml_xgboost")
            bench.set_instance("41146")
            cls._bench = bench
        return cls._bench

    def _XGBoost_HPO(self, X: Tensor) -> Tensor:
        """
        The XGBoost Hyperparameter Optimization problem using YAHPO Gym.

        Also caches the ramtrain constraint values (and a copy of ``X``) for
        ``_evaluate_slack_true``.

        Args:
            X: Tensor of shape (..., n_features) whose last dimension follows
                ``design_var_names``. Supports arbitrary batch dimensions.

        Returns:
            Tensor: shape (..., n_objectives) with objective values, in the
                dtype and on the device of ``X``:
                    - Objective 1: -f1 (f1 is maximized)
                    - Objective 2: timetrain (minimization)
        """
        # Store original batch shape and flatten to (n_samples, n_features)
        original_shape = X.shape[:-1]
        X_flat = X.reshape(-1, X.shape[-1])
        if X_flat.shape[0] == 0:
            # Nothing to evaluate; avoid querying YAHPO with an empty batch.
            self._cache_ramtrain = np.empty(0, dtype=np.float64)
            self._cache_X = X.detach().clone()
            return X.new_empty((*original_shape, self.num_objectives))

        # YAHPO consumes plain Python/numpy values, not tensors
        X_np = X_flat.detach().cpu().numpy().astype(np.float64)

        names = self.design_var_names
        configs = []
        for row in X_np:
            config = {name: float(row[j]) for j, name in enumerate(names)}
            # Integer hyperparameters; int() truncates toward zero, so the
            # upper bound is only reached when X lies exactly on it.
            config['max_depth'] = int(config['max_depth'])
            config['nrounds'] = int(config['nrounds'])
            config['booster'] = 'dart'  # use dart to enable rate_drop and skip_drop
            config['task_id'] = '41146'   # fixed instance
            configs.append(config)

        results = pd.DataFrame(self._get_bench().objective_function(configs))

        # Objectives (f1 negated so that both are minimized) and constraint
        f1 = results['f1'].to_numpy()
        timetrain = results['timetrain'].to_numpy()
        self._cache_ramtrain = results['ramtrain'].to_numpy()
        self._cache_X = X.detach().clone()

        objectives = np.column_stack([-f1, timetrain])

        # Back to a tensor with the original batch dimensions: (..., n_objectives)
        objectives_tensor = torch.tensor(
            objectives, dtype=X.dtype, device=X.device)
        return objectives_tensor.reshape(*original_shape, -1)


class HPORanger(MultiObjectiveTestProblem):
    r"""
    The Ranger Hyperparameter Optimization problem using YAHPO Gym.

    Evaluations are answered by the YAHPO Gym ``iaml_ranger`` benchmark on
    instance ``1489`` with ``replace='TRUE'``, ``splitrule='gini'`` and
    ``respect.unordered.factors='ignore'``.

    Objectives (3-D):
        - mmce (to minimize)
        - nf (to minimize)
        - ias (to minimize)
    """

    dim = 5
    num_objectives = 3
    design_var_names = [
        'min.node.size',
        'mtry.ratio',
        'num.trees',
        'sample.fraction',
        'trainsize'
    ]
    objective_var_names = [
        'mmce',     # classification error
        'nf',       # number of features used
        'ias',      # interaction strength of features
    ]
    fidelity_param_name = 'trainsize'

    # NOTE(review): min.node.size (index 0) and num.trees (index 2) are
    # truncated to int before evaluation, yet every index is declared
    # continuous here.
    continuous_inds = list(range(5))
    discrete_inds = []
    categorical_inds = []

    _bounds = [
        (1, 100),    # min.node.size
        (0.0, 1.0),  # mtry.ratio
        (1, 2000),   # num.trees
        (0.1, 1.0),  # sample.fraction
        (0.03, 1.0), # trainsize (fidelity parameter)
    ]
    _ref_point = [0.3669216, 6.0, 4.0606165]
    # Approximated using NSGA-II with
    #   population_size=250
    #   max_gen=100
    #   seeds=random.sample(range(1000), 100)
    # For seed 414
    _max_hv = 2.4548258632628155

    # YAHPO benchmark shared by all instances of this class; created on first
    # use so it is constructed once instead of on every evaluation.
    _bench = None

    def __init__(
        self,
        noise_std: None | float | list[float] = None,
        negate: bool = False,
        dtype: torch.dtype = torch.double,
    ) -> None:
        r"""
        Args:
            noise_std: Standard deviation of the observation noise.
            negate: If True, negate the objectives.
            dtype: The dtype that is used for the bounds of the function.
        """
        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)

    def _evaluate_true(self, X: Tensor) -> Tensor:
        return self._Ranger_HPO(X)

    @classmethod
    def _get_bench(cls):
        """Return the cached ``iaml_ranger`` benchmark set to instance 1489."""
        if cls._bench is None:
            bench = benchmark_set.BenchmarkSet("iaml_ranger")
            bench.set_instance("1489")
            cls._bench = bench
        return cls._bench

    @staticmethod
    def _Ranger_HPO(X: Tensor) -> Tensor:
        """
        The Ranger algorithm (fast implementation of Random Forests) Hyperparameter
        Optimization problem using YAHPO Gym.

        Args:
            X: Tensor of shape (..., n_features) whose last dimension follows
                ``design_var_names``. Supports arbitrary batch dimensions.

        Returns:
            Tensor: shape (..., n_objectives) with objective values, in the
                dtype and on the device of ``X``:
                    - Objective 1: mmce (minimization)
                    - Objective 2: nf (minimization)
                    - Objective 3: ias (minimization)
        """
        # Store original batch shape and flatten to (n_samples, n_features)
        original_shape = X.shape[:-1]
        X_flat = X.reshape(-1, X.shape[-1])
        if X_flat.shape[0] == 0:
            # Nothing to evaluate; avoid querying YAHPO with an empty batch.
            return X.new_empty((*original_shape, HPORanger.num_objectives))

        # YAHPO consumes plain Python/numpy values, not tensors
        X_np = X_flat.detach().cpu().numpy().astype(np.float64)

        # Tunable parameters come from X; the remaining (categorical) ones
        # and task_id are fixed below.
        names = HPORanger.design_var_names
        configs = []
        for row in X_np:
            config = {name: float(row[j]) for j, name in enumerate(names)}
            # Integer hyperparameters; int() truncates toward zero, so the
            # upper bound is only reached when X lies exactly on it.
            config['min.node.size'] = int(config['min.node.size'])
            config['num.trees'] = int(config['num.trees'])
            config['replace'] = 'TRUE'
            config['respect.unordered.factors'] = 'ignore'
            config['splitrule'] = 'gini'
            config['task_id'] = '1489'   # fixed instance
            configs.append(config)

        results = pd.DataFrame(
            HPORanger._get_bench().objective_function(configs)
        )

        # Extract objectives (mmce, nf and ias)
        objectives = np.column_stack([
            results['mmce'].to_numpy(),
            results['nf'].to_numpy(),
            results['ias'].to_numpy(),
        ])

        # Back to a tensor with the original batch dimensions: (..., n_objectives)
        objectives_tensor = torch.tensor(
            objectives, dtype=X.dtype, device=X.device)
        return objectives_tensor.reshape(*original_shape, -1)
    

## --------------- Constrained Versions --------------- ##

class HPORangerConstrained(
    MultiObjectiveTestProblem,
    ConstrainedBaseTestProblem
):
    r"""
    The constrained Ranger Hyperparameter Optimization problem using YAHPO Gym.

    Evaluations are answered by the YAHPO Gym ``iaml_ranger`` benchmark on
    instance ``1489`` with ``replace='TRUE'``, ``splitrule='gini'`` and
    ``respect.unordered.factors='ignore'``.

    Objectives (2-D):
        - auc (to maximize; returned negated as -auc so both are minimized)
        - logloss (to minimize)

    Constraints:
        - timetrain <= 1.3 (see ``_timetrain_limit``)
    """

    dim = 5
    num_objectives = 2
    num_constraints = 1

    design_var_names = [
        'min.node.size',
        'mtry.ratio',
        'num.trees',
        'sample.fraction',
        'trainsize'
    ]
    objective_var_names = [
        'auc',
        'logloss',
    ]
    constraint_var_names = ['timetrain']
    fidelity_param_name = 'trainsize'

    # NOTE(review): min.node.size (index 0) and num.trees (index 2) are
    # truncated to int before evaluation, yet every index is declared
    # continuous here.
    continuous_inds = list(range(5))
    discrete_inds = []
    categorical_inds = []

    _bounds = [
        (1, 100),    # min.node.size
        (0.0, 1.0),  # mtry.ratio
        (1, 2000),   # num.trees
        (0.1, 1.0),  # sample.fraction
        (0.03, 1.0), # trainsize (fidelity parameter)
    ]
    _ref_point = [-0.5, 7.0]
    # Approximated using NSGA-II with
    #   population_size=250
    #   max_gen=100
    #   seeds=random.sample(range(1000), 100)
    # For seed 581
    _max_hv = 3.1245282643965755

    # Upper limit on timetrain; slack = _timetrain_limit - timetrain.
    _timetrain_limit = 1.3

    # YAHPO benchmark shared by all instances of this class; created on first
    # use so it is constructed once instead of on every evaluation.
    _bench = None

    def __init__(
        self,
        noise_std: None | float | list[float] = None,
        negate: bool = False,
        dtype: torch.dtype = torch.double,
    ) -> None:
        r"""
        Args:
            noise_std: Standard deviation of the observation noise.
            negate: If True, negate the objectives.
            dtype: The dtype that is used for the bounds of the function.
        """
        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
        # Flat timetrain values from the most recent benchmark run, plus a
        # copy of the input that produced them (used to detect stale caches).
        self._cache_timetrain = None
        self._cache_X = None

    def _evaluate_true(self, X: Tensor) -> Tensor:
        return self._Ranger_HPO(X)

    def _cache_matches(self, X: Tensor) -> bool:
        """Return True if the cached timetrain values were computed for ``X``."""
        cached = self._cache_X
        return (
            self._cache_timetrain is not None
            and cached is not None
            and cached.shape == X.shape
            and cached.dtype == X.dtype
            and cached.device == X.device
            and torch.equal(cached, X)
        )

    def _evaluate_slack_true(self, X: Tensor) -> Tensor:
        """
        Evaluate the constraint slack for input tensor X.

        Reuses the timetrain values cached by ``_Ranger_HPO`` only when they
        were produced for an input equal to ``X``; otherwise the benchmark is
        run on ``X`` first, so the slack always corresponds to ``X``.

        Args:
            X: A tensor of shape (..., n_features) with hyperparameter configurations.

        Returns:
            A tensor of shape (..., 1) containing ``_timetrain_limit - timetrain``
            (constraint: timetrain <= 1.3); negative values indicate a
            constraint violation.
        """
        if not self._cache_matches(X):
            self._Ranger_HPO(X)

        timetrain_tensor = torch.tensor(
            self._cache_timetrain,
            dtype=X.dtype,
            device=X.device
        )
        slack = self._timetrain_limit - timetrain_tensor
        # Restore the batch dimensions of X: (..., 1)
        return slack.reshape(*X.shape[:-1], 1)

    @classmethod
    def _get_bench(cls):
        """Return the cached ``iaml_ranger`` benchmark set to instance 1489."""
        if cls._bench is None:
            bench = benchmark_set.BenchmarkSet("iaml_ranger")
            bench.set_instance("1489")
            cls._bench = bench
        return cls._bench

    def _Ranger_HPO(self, X: Tensor) -> Tensor:
        """
        The Ranger algorithm (fast implementation of Random Forests) Hyperparameter
        Optimization problem using YAHPO Gym.

        Also caches the timetrain constraint values (and a copy of ``X``) for
        ``_evaluate_slack_true``.

        Args:
            X: Tensor of shape (..., n_features) whose last dimension follows
                ``design_var_names``. Supports arbitrary batch dimensions.

        Returns:
            Tensor: shape (..., n_objectives) with objective values, in the
                dtype and on the device of ``X``:
                    - Objective 1: -auc (auc is maximized)
                    - Objective 2: logloss (minimization)
        """
        # Store original batch shape and flatten to (n_samples, n_features)
        original_shape = X.shape[:-1]
        X_flat = X.reshape(-1, X.shape[-1])
        if X_flat.shape[0] == 0:
            # Nothing to evaluate; avoid querying YAHPO with an empty batch.
            self._cache_timetrain = np.empty(0, dtype=np.float64)
            self._cache_X = X.detach().clone()
            return X.new_empty((*original_shape, self.num_objectives))

        # YAHPO consumes plain Python/numpy values, not tensors
        X_np = X_flat.detach().cpu().numpy().astype(np.float64)

        # Tunable parameters come from X; the remaining (categorical) ones
        # and task_id are fixed below.
        names = self.design_var_names
        configs = []
        for row in X_np:
            config = {name: float(row[j]) for j, name in enumerate(names)}
            # Integer hyperparameters; int() truncates toward zero, so the
            # upper bound is only reached when X lies exactly on it.
            config['min.node.size'] = int(config['min.node.size'])
            config['num.trees'] = int(config['num.trees'])
            config['replace'] = 'TRUE'
            config['respect.unordered.factors'] = 'ignore'
            config['splitrule'] = 'gini'
            config['task_id'] = '1489'   # fixed instance
            configs.append(config)

        results = pd.DataFrame(self._get_bench().objective_function(configs))

        # Objectives (auc negated so that both are minimized) and constraint
        auc = results['auc'].to_numpy()
        logloss = results['logloss'].to_numpy()
        self._cache_timetrain = results['timetrain'].to_numpy()
        self._cache_X = X.detach().clone()

        objectives = np.column_stack([-auc, logloss])

        # Back to a tensor with the original batch dimensions: (..., n_objectives)
        objectives_tensor = torch.tensor(
            objectives, dtype=X.dtype, device=X.device)
        return objectives_tensor.reshape(*original_shape, -1)