"""Lightweight property prediction head for VAE encoder latents.

This module is used BOTH during VAE fine-tuning (to orient the latent space)
AND during inference-time guidance (to steer generation toward target properties).
"""

from __future__ import annotations
from typing import Sequence

import torch
import torch.nn as nn

from moltenflow.models.vae import TokenPooling

# Supported aggregation methods for latent tokens.
# SurrogateHead checks its `aggregation` argument against this tuple and
# raises ValueError for any other value.
AGGREGATION_METHODS = ("mean", "first", "flatten", "attention")


class BoundedOutput(nn.Module):
    """Apply bounded output transformation using scaled sigmoid.

    For each property with bounds [low, high], computes:
        output = low + (high - low) * sigmoid(raw_output)

    Properties without bounds (None) pass through unchanged.

    This transformation is differentiable everywhere and smoothly maps
    unbounded network outputs to the specified range, making it suitable
    for gradient-based optimization.

    Args:
        bounds: List of (low, high) tuples or None for each output dimension.
               None entries indicate unbounded properties.

    Raises:
        ValueError: If a non-None entry is not a (low, high) pair of finite
            numbers with low < high.
    """

    def __init__(self, bounds: Sequence[tuple[float, float] | None]):
        super().__init__()
        # Keep the caller's entries untouched so extra_repr shows them as given.
        self.bounds = list(bounds)
        self.n_outputs = len(self.bounds)

        # Pre-compute which outputs are bounded
        self.bounded_mask = [b is not None for b in self.bounds]
        self.has_bounded = any(self.bounded_mask)

        if self.has_bounded:
            # Store bounds as tensors for efficient computation. Unbounded
            # slots get placeholder values (0, 1); they are never used because
            # forward() selects the raw input for those positions.
            lows = []
            ranges = []
            for index, bound in enumerate(self.bounds):
                if bound is None:
                    lows.append(0.0)
                    ranges.append(1.0)
                    continue
                low, high = self._validated_bound(index, bound)
                lows.append(low)
                ranges.append(high - low)

            self.register_buffer("lows", torch.tensor(lows, dtype=torch.float32))
            self.register_buffer("ranges", torch.tensor(ranges, dtype=torch.float32))
            self.register_buffer("mask", torch.tensor(self.bounded_mask, dtype=torch.bool))

    @staticmethod
    def _validated_bound(index: int, bound) -> tuple[float, float]:
        """Return ``bound`` as a (low, high) float pair or raise ValueError."""
        try:
            low, high = bound
            low, high = float(low), float(high)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"bounds[{index}] must be None or a (low, high) pair, got {bound!r}"
            ) from err

        # The chained comparison rejects NaN (all comparisons False), infinite
        # limits (which would yield inf/NaN outputs) and low >= high (which
        # would invert or collapse the output range).
        inf = float("inf")
        if not (-inf < low < high < inf):
            raise ValueError(
                f"bounds[{index}] must be finite with low < high, got {bound!r}"
            )
        return low, high

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply bounded transformation.

        Args:
            x: Raw network output of shape (B, n_outputs). Any number of
               leading dimensions is accepted as long as the last dimension
               has size n_outputs.

        Returns:
            Tensor of the same shape as ``x`` with bounded properties
            constrained to their ranges. When no property is bounded, ``x``
            itself is returned unchanged.

        Raises:
            ValueError: If at least one property is bounded and the last
                dimension of ``x`` is not n_outputs.
        """
        if not self.has_bounded:
            return x

        # Reject mismatched widths explicitly: a trailing dimension of 1 would
        # otherwise broadcast silently against the per-output buffers.
        if x.dim() == 0 or x.shape[-1] != self.n_outputs:
            raise ValueError(
                f"Expected input with last dimension {self.n_outputs}, "
                f"got shape {tuple(x.shape)}"
            )

        # output = low + range * sigmoid(x) for bounded
        # output = x for unbounded
        bounded = self.lows + self.ranges * torch.sigmoid(x)

        # The (n_outputs,) mask broadcasts over all leading dimensions, so the
        # result always has the same shape as x.
        return torch.where(self.mask, bounded, x)

    def extra_repr(self) -> str:
        """Return a per-output summary of the configured bounds."""
        bounded_str = ", ".join(
            f"{i}:{b}" if b is not None else f"{i}:unbounded"
            for i, b in enumerate(self.bounds)
        )
        return f"bounds=[{bounded_str}]"


class SurrogateHead(nn.Module):
    """Property prediction head for VAE encoder latents with optional conditions.

    Predicts molecular properties from latent token representations z (B, K, D)
    and optional conditional variables c (e.g., temperature, pressure).
    Aggregates over K tokens before prediction to enable gradient flow through
    the VAE encoder during property-aware fine-tuning.

    This model serves dual purposes:
    1. During VAE fine-tuning: enables property loss gradients to backpropagate
       through the encoder, orienting the latent space.
    2. During inference: computes property gradients for guidance-based generation.

    Aggregation Methods:
        - 'mean': Average over K tokens -> (B, D). Simple, parameter-free.
        - 'first': Take first token -> (B, D). CLS-token style, parameter-free.
        - 'flatten': Concatenate K tokens -> (B, K*D). Preserves all information.
        - 'attention': Learned attention pooling -> (B, D). Uses TokenPooling with
          a single learned query to attend over K tokens.

    Output Bounds:
        When `output_bounds` is provided, bounded properties use scaled sigmoid:
            output = low + (high - low) * sigmoid(raw_output)
        This ensures outputs stay within [low, high] while remaining differentiable.

    Args:
        K: Number of latent tokens
        d_latent: Latent token dimensionality
        out_dim: Number of properties to predict
        cond_dim: Conditional variable dimensionality (0 = no conditions)
        hidden_dim: Hidden layer dimension for MLP
        aggregation: Aggregation method ('mean', 'first', 'flatten', or 'attention')
        dropout: Dropout probability
        output_bounds: Optional list of (low, high) tuples for each property.
                      Use None for unbounded properties. Length must match out_dim.
                      Example: [(0.0, 1.0), (1.0, 10.0), None] for 3 properties.

    Raises:
        ValueError: If `aggregation` is not in AGGREGATION_METHODS, if
            `cond_dim` is negative, or if `output_bounds` does not have
            exactly `out_dim` entries.
    """

    def __init__(
        self,
        K: int,
        d_latent: int,
        out_dim: int,
        cond_dim: int = 0,
        hidden_dim: int = 256,
        aggregation: str = "mean",
        dropout: float = 0.1,
        output_bounds: Sequence[tuple[float, float] | None] | None = None,
    ):
        super().__init__()

        if aggregation not in AGGREGATION_METHODS:
            raise ValueError(
                f"Unknown aggregation: {aggregation}. Use one of {AGGREGATION_METHODS}."
            )

        # A negative cond_dim would shrink the first Linear layer below the
        # aggregated latent width, so forward() could never succeed.
        if cond_dim < 0:
            raise ValueError(f"cond_dim must be >= 0, got {cond_dim}")

        if output_bounds is not None and len(output_bounds) != out_dim:
            raise ValueError(
                f"output_bounds length ({len(output_bounds)}) must match out_dim ({out_dim})"
            )

        self.K = K
        self.d_latent = d_latent
        self.out_dim = out_dim
        self.cond_dim = cond_dim
        self.aggregation = aggregation

        # Determine input dimension after aggregation
        if aggregation == "flatten":
            latent_dim = K * d_latent
        else:
            # mean, first, attention all produce (B, D)
            latent_dim = d_latent

        # Create attention pooling module if needed
        # Uses single query (K=1) to pool K tokens down to 1 token
        self.attention_pool = None
        if aggregation == "attention":
            self.attention_pool = TokenPooling(d_model=d_latent, K=1)

        # Total input dimension includes aggregated latent + conditions
        in_dim = latent_dim + cond_dim

        # MLP layers
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )

        # Output bounding transformation
        self.output_transform: BoundedOutput | None = None
        if output_bounds is not None:
            self.output_transform = BoundedOutput(output_bounds)

    def forward(self, z: torch.Tensor, c: torch.Tensor | None = None) -> torch.Tensor:
        """Predict properties from latent tokens.

        Args:
            z: Latent tokens of shape (B, K, D)
            c: Optional conditional variables of shape (B, cond_dim)

        Returns:
            Property predictions of shape (B, out_dim)

        Raises:
            ValueError: If `z` is not (B, K, D) with the configured K and D,
                if `c` is missing, unexpected, or not (B, cond_dim), or if the
                stored aggregation method is unknown.
        """
        # Check rank first so a malformed z gets a descriptive message rather
        # than a tuple-unpacking error.
        if z.dim() != 3:
            raise ValueError(
                f"Expected z of shape (B, {self.K}, {self.d_latent}), got {tuple(z.shape)}"
            )

        B, K, D = z.shape

        if K != self.K or D != self.d_latent:
            raise ValueError(
                f"Expected z of shape (B, {self.K}, {self.d_latent}), got (B, {K}, {D})"
            )

        # Aggregate tokens based on method
        if self.aggregation == "mean":
            z_agg = z.mean(dim=1)  # (B, D)
        elif self.aggregation == "first":
            z_agg = z[:, 0, :]  # (B, D) - take first token
        elif self.aggregation == "flatten":
            # Spell out K * D instead of -1: reshape cannot infer -1 when the
            # tensor has zero elements, which breaks on an empty batch.
            z_agg = z.reshape(B, K * D)  # (B, K*D)
        elif self.aggregation == "attention":
            # Use attention pooling with single query -> (B, 1, D) -> squeeze to (B, D)
            z_agg = self.attention_pool(z).squeeze(1)  # (B, D)
        else:
            # Only reachable if self.aggregation was changed after construction.
            raise ValueError(f"Unknown aggregation: {self.aggregation}")

        # Concatenate conditions if provided
        if self.cond_dim > 0:
            if c is None:
                raise ValueError(
                    f"Model expects {self.cond_dim} conditional variables but got None"
                )
            if c.dim() != 2:
                raise ValueError(
                    f"Expected c of shape (B, {self.cond_dim}), got {tuple(c.shape)}"
                )
            if c.shape[1] != self.cond_dim:
                raise ValueError(f"Expected {self.cond_dim} conditional dims, got {c.shape[1]}")
            # torch.cat does not broadcast, so report a batch mismatch here
            # with a clear message.
            if c.shape[0] != B:
                raise ValueError(
                    f"Batch size mismatch: z has {B} samples but c has {c.shape[0]}"
                )
            x = torch.cat([z_agg, c], dim=1)
        else:
            if c is not None:
                raise ValueError("Model has cond_dim=0 but conditional variables were provided")
            x = z_agg

        # Predict properties
        out = self.net(x)

        # Apply output bounds if configured
        if self.output_transform is not None:
            out = self.output_transform(out)

        return out
