from __future__ import annotations
from typing import Literal
import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator


def generate_toy_latents(n_samples: int, latent_dim: int, seed: int = 0) -> torch.Tensor:
    """Draw reproducible standard-normal latent vectors.

    Values come from N(0, I) via a NumPy ``Generator`` seeded with ``seed``,
    so the same arguments always produce the same tensor.

    Args:
        n_samples: Number of latent vectors (rows).
        latent_dim: Size of each latent vector (columns).
        seed: Seed for the NumPy random generator.

    Returns:
        float32 tensor of shape (n_samples, latent_dim), sharing memory with
        the NumPy buffer it was built from.
    """
    generator = np.random.default_rng(seed)
    samples = generator.normal(loc=0.0, scale=1.0, size=(n_samples, latent_dim))
    return torch.from_numpy(samples.astype(np.float32))


def generate_toy_conditions(
    n_samples: int, cond_dim: int, seed: int = 0
) -> torch.Tensor | None:
    """Generate toy conditional variables (e.g., temperature, pressure).

    Each column is sampled uniformly and independently:
    - Column 0: temperature-like, [273, 373) K
    - Column 1: pressure-like, [0.1, 10) bar
    - Columns 2+: generic, [0, 1)

    Columns are drawn one after another from a single generator seeded with
    ``seed``. The first k columns are therefore identical for every
    ``cond_dim >= k``.

    Args:
        n_samples: Number of samples to generate.
        cond_dim: Number of conditional variables (columns). 0 means
            "no conditioning".
        seed: Random seed for reproducibility.

    Returns:
        float32 tensor of shape (n_samples, cond_dim), or ``None`` when
        ``cond_dim == 0``.

    Raises:
        ValueError: If ``cond_dim`` is negative.
    """
    if cond_dim == 0:
        # No conditioning requested; callers treat None as "no c".
        return None
    if cond_dim < 0:
        raise ValueError(f"cond_dim must be non-negative, got {cond_dim}")

    # Physically motivated ranges for the leading columns; the rest use [0, 1).
    named_ranges = ((273.0, 373.0), (0.1, 10.0))

    rng = np.random.default_rng(seed)
    columns = []
    for i in range(cond_dim):
        low, high = named_ranges[i] if i < len(named_ranges) else (0.0, 1.0)
        columns.append(rng.uniform(low, high, size=(n_samples,)).astype(np.float32))

    return torch.from_numpy(np.stack(columns, axis=1))


def generate_toy_targets(
    z: torch.Tensor,
    n_properties: int = 2,
    c: torch.Tensor | None = None,
    seed: int = 0,
) -> torch.Tensor:
    """Generate synthetic property targets with linear + nonlinear components.

    Targets depend on ``x = z`` or, when conditions are given,
    ``x = concat([z, c], dim=1)``:

        y[0] = x @ w1 + noise                               (linear)
        y[1] = relu(x @ w2) + 0.1 * sum(x**2) + noise       (nonlinear)
        y[k] = tanh(x @ w_k) + noise          for k >= 2

    Weights use std 1.0 (w1) or 0.5 (others); noise uses std 0.1. Both are
    drawn, in that fixed order, from a NumPy generator seeded with ``seed``,
    so outputs are reproducible for a given input.

    Args:
        z: Latent vectors of shape (n_samples, latent_dim).
        n_properties: Number of target properties (columns) to generate.
        c: Optional conditional variables of shape (n_samples, cond_dim).
        seed: Random seed for reproducibility.

    Returns:
        Tensor of shape (n_samples, n_properties) on ``x``'s device. Its dtype
        is ``x``'s floating dtype, or float32 if ``x`` is not floating point.

    Raises:
        ValueError: If ``n_properties`` is negative.
    """
    if n_properties < 0:
        raise ValueError(f"n_properties must be non-negative, got {n_properties}")

    rng = np.random.default_rng(seed)
    x = z if c is None else torch.cat([z, c], dim=1)
    if not x.is_floating_point():
        # Integer/bool inputs (e.g. raw bit vectors) cannot be matmul'd with float weights.
        x = x.to(torch.float32)
    n_samples, x_dim = x.shape

    def _draw(scale: float, size: int) -> torch.Tensor:
        # Sample from the shared RNG, then move onto x's dtype/device so all
        # arithmetic stays in torch (no implicit NumPy interop, which fails
        # for tensors that require grad or live off-CPU).
        sample = rng.normal(0.0, scale, size=(size,)).astype(np.float32)
        return torch.from_numpy(sample).to(device=x.device, dtype=x.dtype)

    # Property 0: linear in x.
    w1 = _draw(1.0, x_dim)
    columns = [x @ w1 + _draw(0.1, n_samples)]

    # Property 1: ReLU projection plus a quadratic term.
    if n_properties >= 2:
        w2 = _draw(0.5, x_dim)
        y2 = torch.relu(x @ w2) + (x**2).sum(dim=1) * 0.1
        columns.append(y2 + _draw(0.1, n_samples))

    # Properties 2+: bounded tanh projections.
    for _ in range(2, n_properties):
        w_extra = _draw(0.5, x_dim)
        columns.append(torch.tanh(x @ w_extra) + _draw(0.1, n_samples))

    # Slicing handles n_properties == 0 (returns an (n_samples, 0) tensor).
    return torch.stack(columns, dim=1)[:, :n_properties]


def rdkit_fingerprints(smiles: list[str], radius: int = 2, nbits: int = 2048) -> np.ndarray:
    """Generate Morgan (circular) fingerprints using RDKit.

    Args:
        smiles: SMILES strings (any iterable; it is materialized once).
        radius: Radius for Morgan fingerprint (default: 2 for ECFP4).
        nbits: Number of bits in fingerprint.

    Returns:
        float32 array of shape (n_samples, nbits) holding 0/1 bit values.
        An empty input yields an array of shape (0, nbits).

    Raises:
        ValueError: If any SMILES string cannot be parsed by RDKit.
    """
    smiles_list = list(smiles)
    if not smiles_list:
        # np.stack raises on an empty sequence; return a well-shaped empty batch.
        return np.empty((0, nbits), dtype=np.float32)

    # Build the MorganGenerator once and reuse it for every molecule.
    gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=nbits)

    fps = []
    for i, smi in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # RDKit signals a parse failure by returning None, not by raising.
            raise ValueError(f"Invalid SMILES at index {i}: {smi}")
        fp = gen.GetFingerprint(mol)
        fps.append(np.array(fp, dtype=np.float32))

    return np.stack(fps, axis=0)


def get_latents(
    smiles: list[str] | None = None,
    backend: Literal["rdkit_fp", "vae"] = "rdkit_fp",
    vae_model=None,
    fp_radius: int = 2,
    fp_nbits: int = 2048,
    device: str = "cpu",
) -> torch.Tensor:
    """Convert SMILES to latent vectors using the specified backend.

    Backends:
    - ``"rdkit_fp"``: Morgan fingerprints (float32 0/1 bits) via
      ``rdkit_fingerprints``.
    - ``"vae"``: ``vae_model.encode(smiles)`` is expected to return
      ``(mu, logvar)``; ``mu`` is used as the latent. The model is switched to
      eval mode for encoding and put back in training mode afterwards if it
      was training. Encoding runs under ``torch.no_grad()``.

    Args:
        smiles: Non-empty list of SMILES strings (required for every backend).
        backend: Encoding backend ("rdkit_fp" or "vae").
        vae_model: Trained VAE model (required if backend="vae").
        fp_radius: Morgan fingerprint radius (for rdkit_fp backend).
        fp_nbits: Morgan fingerprint size (for rdkit_fp backend).
        device: Device the returned tensor is moved to.

    Returns:
        Tensor of shape (n_samples, latent_dim) on ``device``.

    Raises:
        ValueError: If ``smiles`` is missing/empty, ``vae_model`` is missing
            for the "vae" backend, or ``backend`` is unknown.
        NotImplementedError: If ``vae_model`` has no ``encode`` method.
    """
    if smiles is None or len(smiles) == 0:
        raise ValueError("smiles list cannot be empty")

    if backend == "rdkit_fp":
        z_np = rdkit_fingerprints(smiles, radius=fp_radius, nbits=fp_nbits)
        return torch.from_numpy(z_np).to(device)

    if backend == "vae":
        if vae_model is None:
            raise ValueError("vae_model is required when backend='vae'")
        # Validate the interface before touching the model's mode, so an
        # unsuitable model is rejected without side effects.
        if not hasattr(vae_model, "encode"):
            raise NotImplementedError("VAE model must implement encode() method")

        # Temporarily use eval mode (e.g. disables dropout for nn.Modules),
        # then restore the caller's training state.
        was_training = bool(getattr(vae_model, "training", False))
        vae_model.eval()
        try:
            with torch.no_grad():
                # NOTE(review): assumes encode() returns (mu, logvar) — confirm
                # against the concrete VAE implementation.
                mu, _ = vae_model.encode(smiles)
        finally:
            if was_training:
                vae_model.train()
        return mu.to(device)

    raise ValueError(f"Unknown backend: {backend}")
