#!/usr/bin/env python3

from __future__ import annotations

from typing import Callable

import torch
from torch import Tensor
from botorch.utils.sampling import draw_sobol_samples
from botorch.test_functions.base import (
    ConstrainedBaseTestProblem,
    MultiObjectiveTestProblem,
)
from botorch.utils.transforms import unnormalize


class GenerateInitialSample:
    def __init__(
        self, 
        problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem, 
        bounds: Tensor, 
        has_constraints: bool
    ):
        r"""
        Initialize the GenerateInitialSample class.

        Args:
            problem: The optimization problem to solve.
            bounds: The bounds for the design variables.
            has_constraints: Whether the problem has constraints.
                NOTE: If has_constraints is True, the problem must implement
                the `evaluate_slack` method.
        """
        self.problem = problem
        self.bounds = bounds
        self.has_constraints = has_constraints

        if self.has_constraints:
            if not hasattr(self.problem, 'evaluate_slack'):
                raise ValueError("Constraints are defined but the problem does not " \
                                 "have an `evaluate_slack` method.")
            

    def generate_initial_data(
        self, 
        n: int,
        seed: int | None = None
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        r"""
        Generate initial training data.
        
        Args:
            n: Number of initial samples
            
        Returns:
            tuple[Tensor, Tensor, Tensor | None]: 
                - train_x: Input samples with dim `n x d` where last 
                    column of `d` corresponds to fidelity levels.
                - train_obj: Objective values with dim `n x m`.
                - train_constraints: Constraint values with dim `n x c`
                    (None if no constraints).
        """
        
        train_x = draw_sobol_samples(bounds=self.bounds, n=n, q=1,
                                     seed=seed).squeeze(1)
        train_obj = self.problem(train_x)
        # Handle constraints if present
        train_constraints = None
        if self.has_constraints:
            # Negative values imply feasibility in botorch
            train_constraints = -self.problem.evaluate_slack(train_x)
        return train_x, train_obj, train_constraints    
    

    def generate_initial_data_by_fidelity(
        self,
        n: int,
        fidelities: list[float],
        seed: int | None = None
    ):
        r""" 
        Generate initial data for each fidelity level.
        We get samples to collect from each fidelity by deviding
        n by number of fidelities.

        Args:
            n (int): Total number of samples to generate.
            fidelities (list[float]): List of fidelity levels to sample from.
            seed (int | None): Random seed for reproducibility.

        Returns:
            tuple[Tensor, Tensor, Tensor | None]: 
                - train_x: Input samples with dim `n x d` where last 
                    column of `d` corresponds to fidelity levels.
                - train_obj: Objective values with dim `n x m`.
                - train_constraints: Constraint values with dim `n x c`
                    (None if no constraints).
        """
        num_fidelities = len(fidelities)
        samples_per_fidelity = n // num_fidelities
        remainder = n % num_fidelities

        train_x_list = []
        for i, fidelity in enumerate(fidelities):
            current_n = samples_per_fidelity + (1 if i < remainder else 0)
            if current_n == 0:
                continue
            # Generate samples for design variables
            x_samples = draw_sobol_samples(
                bounds=self.bounds[:, :-1], 
                n=current_n, 
                q=1,
                seed=seed
            ).squeeze(1)
            # Append fidelity level as the last column
            fidelity_column = torch.full(
                (current_n, 1), 
                fidelity, 
                dtype=x_samples.dtype, 
                device=x_samples.device
            )
            x_with_fidelity = torch.cat([x_samples, fidelity_column], dim=-1)
            train_x_list.append(x_with_fidelity)

        train_x = torch.cat(train_x_list, dim=0)
        train_obj = self.problem(train_x)
        train_constraints = None
        if self.has_constraints:
            train_constraints = -self.problem.evaluate_slack(train_x)
        return train_x, train_obj, train_constraints

    def generate_initial_x_by_fidelity(
        self,
        n: int,
        fidelities: list[float],
        seed: int | None = None
    ):
        r""" 
        Generate initial data for each fidelity level.
        We get samples to collect from each fidelity by deviding
        n by number of fidelities.

        Args:
            n (int): Total number of samples to generate.
            fidelities (list[float]): List of fidelity levels to sample from.
            seed (int | None): Random seed for reproducibility.

        Returns:
            Tensor: 
                - train_x: Input samples with dim `n x d` where last 
        """
        num_fidelities = len(fidelities)
        samples_per_fidelity = n // num_fidelities
        remainder = n % num_fidelities

        train_x_list = []
        for i, fidelity in enumerate(fidelities):
            current_n = samples_per_fidelity + (1 if i < remainder else 0)
            if current_n == 0:
                continue
            # Generate samples for design variables
            x_samples = draw_sobol_samples(
                bounds=self.bounds[:, :-1], 
                n=current_n, 
                q=1,
                seed=seed
            ).squeeze(1)
            # Append fidelity level as the last column
            fidelity_column = torch.full(
                (current_n, 1), 
                fidelity, 
                dtype=x_samples.dtype, 
                device=x_samples.device
            )
            x_with_fidelity = torch.cat([x_samples, fidelity_column], dim=-1)
            train_x_list.append(x_with_fidelity)

        train_x = torch.cat(train_x_list, dim=0)
        return train_x

    def generate_initial_data_multifidelity(
        self,
        n_full_fidelity_equiv: float, 
        cost_fn: Callable[[Tensor], Tensor],
        is_discrete: bool = False,
        fidelity_levels: Tensor | None = None,
        n_grid: int | None = 1000,
        verbose: bool = False
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        r"""
        Method to generate initial data for both continuous 
        and discrete fidelity cases. Sample fidelities inversely 
        proportional to cost.
        
        Assumptions:
            - Fidelity range: s ∈ [0, 1] where 0 is lowest fidelity and 1 is highest fidelity.
            - Cost function C(s) is monotonically increasing with fidelity s.
            - Full fidelity corresponds to s = 1.0.
        
        Mathematical formulation:
            Budget constraint: Σᵢ C(xᵢ) ≤ B
            where B = n_full_fidelity_equiv × C(s_max)
            
        Sampling strategy:
            For each sample xᵢ:
            1. Sample design variables uniformly: x ~ U(bounds)
            2. Sample fidelity s from p(s) ∝ 1/C(s)
            3. Continue until budget B is exhausted
        
        Args:
            n_full_fidelity_equiv (float): Budget in terms of number 
                of full-fidelity evaluations.
            cost_fn (Callable[[Tensor], Tensor]): Function that returns 
                cost given input tensor.
            is_discrete (bool): If True, use discrete fidelity levels; 
                if False, continuous.
            fidelity_levels (Tensor | None): Tensor of available discrete 
                fidelity levels (required if is_discrete=True).
            n_grid (int | None): Number of grid points for numerical approximation 
                (required if is_discrete=False).
            
        Returns:
            tuple[Tensor, Tensor, Tensor | None]: 
                - train_x: Input samples with dim `n x d` where last 
                    column of `d` corresponds to fidelity levels.
                - train_obj: Objective values with dim `n x m`.
                - train_constraints: Constraint values with dim `n x c`
                    (None if no constraints).
        """
        if is_discrete and fidelity_levels is None:
            raise ValueError("fidelity_levels must be provided when is_discrete=True")
        if not is_discrete and n_grid is None:
            raise ValueError("n_grid must be provided when is_discrete=False")
            
        train_x = torch.empty(
            0, self.bounds.shape[1], 
            dtype=self.bounds.dtype, 
            device=self.bounds.device
        )
        total_cost = 0
        
        # Calculate total cost limit based on n_full_fidelity_equiv
        if is_discrete:
            max_fidelity_input = torch.zeros(
                1, self.bounds.shape[1],
                dtype=self.bounds.dtype, 
                device=self.bounds.device
            )
            max_fidelity_input[:, -1] = fidelity_levels.max()
            full_fidelity_cost = cost_fn(max_fidelity_input).item()
        else:
            full_fidelity_input = torch.ones(
                1, 
                self.bounds.shape[1], 
                dtype=self.bounds.dtype, 
                device=self.bounds.device
            )
            full_fidelity_cost = cost_fn(full_fidelity_input).item()
            
        total_cost_limit = n_full_fidelity_equiv * full_fidelity_cost

        while total_cost < total_cost_limit:
            # Sample design variables
            new_x = torch.rand(1, self.bounds.shape[1], 
                            dtype=self.bounds.dtype, 
                            device=self.bounds.device
                    )
            
            # Sample fidelity level inversely proportional to cost
            new_x[:, -1] = self.inv_transform(
                new_x[:, -1], 
                cost_fn, 
                is_discrete=is_discrete,
                fidelity_levels=fidelity_levels,
                n_grid=n_grid
            )
            
            # Calculate cost and check budget
            new_cost = cost_fn(new_x).item()
            if total_cost + new_cost > total_cost_limit:
                break
                
            total_cost += new_cost
            train_x = torch.cat([train_x, new_x], dim=0)
            if verbose:
                print(
                    f"curr_cost/cost_limit: {total_cost:.2f}/{total_cost_limit:.2f}"
                )
        if is_discrete:
            # For discrete fidelity, unnormalize only the design variables 
            # (all columns except last)
            train_x[:, :-1] = unnormalize(train_x[:, :-1], bounds=self.bounds[:, :-1])
        else:
            # For continuous fidelity, unnormalize all variables
            train_x = unnormalize(train_x, bounds=self.bounds)
        
        train_obj = self.problem(train_x)
        train_constraints = None
        if self.has_constraints:
            train_constraints = -self.problem.evaluate_slack(train_x)
            
        return train_x, train_obj, train_constraints

    def inv_transform(
        self, 
        u: Tensor, 
        cost_fn: Callable[[Tensor], Tensor], 
        is_discrete: bool = False,
        fidelity_levels: Tensor | None = None,
        n_grid: int | None = None
    ):
        r""" 
        Inverse transform to sample from probability distribution 
        with PDF proportional to 1/cost, for any cost function.
        
        Assumptions:
            - Fidelity range: s ∈ [0, 1] where 0 is lowest fidelity 
              and 1 is highest fidelity.
            - Cost function C(s) is monotonically increasing with fidelity s.
            - Full fidelity corresponds to s = 1.0.
        
        Mathematical formulation:
            p(s) ∝ 1/C(s)
            where C(s) is the cost function and s is the fidelity level.
            
        Discrete case:
            p(s_i) = (1/C(s_i)) / Σ_j(1/C(s_j)) for i ∈ {1, ..., K}
            where s_i are the discrete fidelity levels.
            
        Continuous case:
            Numerical approximation using grid discretization:
            - Create grid: s_1, s_2, ..., s_N ∈ [0, 1]
            - Evaluate costs: C(s_i) for all grid points
            - Compute probabilities: p(s_i) = (1/C(s_i)) / Σ_j(1/C(s_j))
            - Build CDF: F(s_i) = Σ_{j≤i} p(s_j)
            - Sample via inverse CDF: s = F⁻¹(u)
        
        Args:
            u (Tensor): Uniform(0,1) random variable.
            cost_fn (Callable[[Tensor], Tensor]): The actual cost function 
                to use for inverse sampling.
            is_discrete (bool): If True, use discrete fidelity levels; 
                if False, continuous.
            fidelity_levels (Tensor | None): Tensor of available discrete fidelity 
                levels (required if is_discrete=True).
            n_grid (int | None): Number of grid points for numerical approximation 
                (required if is_discrete=False).
        """
        # Validation
        if is_discrete:
            if fidelity_levels is None:
                raise ValueError("fidelity_levels must be provided when is_discrete=True")
        else:
            if n_grid is None:
                raise ValueError("n_grid must be provided when is_discrete=False")
        
        # Determine fidelity points to evaluate
        if is_discrete:
            # Ensure fidelity_levels is on the correct device and dtype
            fidelity_points = fidelity_levels.to(
                dtype=self.bounds.dtype, device=self.bounds.device
            )
        else:
            fidelity_points = torch.linspace(
                0, 1, n_grid, 
                dtype=self.bounds.dtype, 
                device=self.bounds.device
            )

        # Create input tensors for cost evaluation
        sample_inputs = torch.zeros(
            len(fidelity_points), self.bounds.shape[1],
            dtype=self.bounds.dtype, 
            device=self.bounds.device
        )
        sample_inputs[:, -1] = fidelity_points
        
        # Evaluate costs at all fidelity points
        costs = torch.tensor(
            [cost_fn(inp.unsqueeze(0)).item() for inp in sample_inputs],
            dtype=self.bounds.dtype,
            device=self.bounds.device
        )
        
        # Compute probabilities inversely proportional to cost
        inv_costs = 1.0 / costs
        probabilities = inv_costs / inv_costs.sum()
        
        # Compute cumulative probabilities
        cum_probs = torch.cumsum(probabilities, dim=0)
        
        # Sample based on uniform random variable using inverse CDF
        indices = torch.searchsorted(cum_probs, u, right=True)
        indices = torch.clamp(indices, 0, len(fidelity_points) - 1)
        
        return fidelity_points[indices]