from typing import Optional

import torch
import torch.nn as nn
from sbi.neural_nets import posterior_nn
from sbi.neural_nets.embedding_nets import FCEmbedding

class MeanEmbedding(nn.Module):
    """Average ``n_obs`` observations of size ``dim`` into one ``dim``-vector.

    The input is reshaped to ``(batch, n_obs, dim)`` and averaged over the
    observation axis, so the result does not depend on observation order.

    Args:
        n_obs: Number of observations per sample.
        dim: Dimensionality of a single observation.
    """

    def __init__(self, n_obs: int, dim: int):
        super().__init__()
        self.n_obs = n_obs
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the mean observation of every sample in ``x``.

        Args:
            x: Either flattened ``(..., n_obs * dim)`` or structured
                ``(..., n_obs, dim)``. Any leading dimensions are folded into
                a single batch dimension.

        Returns:
            Tensor of shape ``(batch, dim)``.

        Raises:
            ValueError: If the trailing dimensions of ``x`` match neither
                layout. Without this check, a mismatched size could be
                silently regrouped across samples by the reshape.
        """
        flat_size = self.n_obs * self.dim
        flat_ok = x.dim() >= 1 and x.shape[-1] == flat_size
        structured_ok = x.dim() >= 2 and tuple(x.shape[-2:]) == (self.n_obs, self.dim)
        if not (flat_ok or structured_ok):
            raise ValueError(
                f"MeanEmbedding(n_obs={self.n_obs}, dim={self.dim}) expects input "
                f"of shape (..., {flat_size}) or (..., {self.n_obs}, {self.dim}), "
                f"got {tuple(x.shape)}"
            )
        # reshape (not view): merging non-contiguous leading dims would make
        # view raise, while reshape falls back to a copy.
        x = x.reshape(-1, self.n_obs, self.dim)
        return x.mean(dim=1)


class IdentityEmbedding(nn.Module):
    """Pass-through embedding: hands its input back untouched.

    Behaves like ``nn.Identity``: no parameters, no copy.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` itself (the very same tensor object)."""
        unchanged = x
        return unchanged

def build_npe_model(
    theta_sample: torch.Tensor,
    x_sample: torch.Tensor,
    embedding_type: str = "mean",
    n_obs: Optional[int] = None,
    dim: Optional[int] = None,
    embedding_dim: Optional[int] = None,
    embedding_hidden: Optional[int] = None,
    embedding_layers: Optional[int] = None,
    num_transforms: Optional[int] = None,
    hidden_features: Optional[int] = None,
):
    """Build a neural-spline-flow ("nsf") posterior estimator via sbi.

    Args:
        theta_sample: Batch of parameter samples, forwarded to the sbi builder.
        x_sample: Batch of observations (batch dimension first), forwarded to
            the sbi builder. Its last dimension is used as the input size of
            the "fc" embedding.
        embedding_type: ``"none"``/``"identity"`` (no embedding), ``"mean"``
            (``MeanEmbedding``) or ``"fc"`` (sbi ``FCEmbedding``).
        n_obs: Observations per sample; required for ``"mean"``.
        dim: Size of one observation; required for ``"mean"``.
        embedding_dim: Output size of the ``"fc"`` embedding; required for ``"fc"``.
        embedding_hidden: Hidden width of the ``"fc"`` embedding (optional).
        embedding_layers: Number of layers of the ``"fc"`` embedding (optional).
        num_transforms: Number of flow transforms (optional).
        hidden_features: Hidden width of the flow (optional).

    Any optional argument left as ``None`` is not passed on, so the sbi
    default for that option applies.

    Returns:
        Whatever the builder returned by ``posterior_nn`` produces when called
        on ``(theta_sample, x_sample)``; the exact type depends on the
        installed sbi version.

    Raises:
        ValueError: On an unknown ``embedding_type``, missing required
            arguments, or (for ``"mean"``) a per-sample size of ``x_sample``
            that does not equal ``n_obs * dim``.
    """
    x_dim = x_sample.shape[-1]

    if embedding_type in ("none", "identity"):
        embedding_net = nn.Identity()

    elif embedding_type == "mean":
        if n_obs is None or dim is None:
            raise ValueError("n_obs and dim required for mean embedding")
        # Fail early with a clear message instead of letting the embedding
        # regroup values across samples inside the flow builder.
        per_sample = x_sample.shape[1:].numel()
        if per_sample != n_obs * dim:
            raise ValueError(
                f"mean embedding expects n_obs * dim = {n_obs * dim} values per "
                f"sample, but x_sample has {per_sample} "
                f"(shape {tuple(x_sample.shape)})"
            )
        embedding_net = MeanEmbedding(n_obs=n_obs, dim=dim)

    elif embedding_type == "fc":
        if embedding_dim is None:
            raise ValueError("embedding_dim required for fc embedding")
        embedding_kwargs = {"input_dim": x_dim, "output_dim": embedding_dim}
        if embedding_layers is not None:
            embedding_kwargs["num_layers"] = embedding_layers
        if embedding_hidden is not None:
            embedding_kwargs["num_hiddens"] = embedding_hidden
        embedding_net = FCEmbedding(**embedding_kwargs)

    else:
        raise ValueError(f"Unknown embedding_type: {embedding_type}")

    # Build the NSF; only forward options the caller actually set.
    flow_kwargs = {"model": "nsf", "embedding_net": embedding_net}
    if num_transforms is not None:
        flow_kwargs["num_transforms"] = num_transforms
    if hidden_features is not None:
        flow_kwargs["hidden_features"] = hidden_features

    builder = posterior_nn(**flow_kwargs)
    return builder(theta_sample, x_sample)

def get_embedding_net(density_estimator):
    """Return the embedding network held by a density estimator.

    Lookup order, first hit wins:

    1. ``density_estimator.net._embedding_net`` (the original lookup);
    2. ``density_estimator.embedding_net`` (public attribute, if the estimator
       exposes one; verify against the installed sbi version);
    3. ``density_estimator._embedding_net`` (estimator that itself stores the
       embedding privately, e.g. a bare nflows ``Flow``).

    Raises:
        AttributeError: If none of the above attributes exist.
    """
    inner = getattr(density_estimator, "net", None)
    if inner is not None and hasattr(inner, "_embedding_net"):
        return inner._embedding_net
    for attr in ("embedding_net", "_embedding_net"):
        if hasattr(density_estimator, attr):
            return getattr(density_estimator, attr)
    raise AttributeError(
        f"cannot locate an embedding net on {type(density_estimator).__name__}: "
        "expected .net._embedding_net, .embedding_net or ._embedding_net"
    )

def compute_summary_statistics(
    embedding_net: nn.Module,
    X_flat: torch.Tensor,
    device: str = "cpu",
) -> torch.Tensor:
    """Run ``embedding_net`` on ``X_flat`` without gradients; return the result on CPU.

    The network runs in eval mode (e.g. dropout disabled). Afterwards, the
    train/eval mode of every submodule is restored, so calling this during
    training does not leave the network stuck in eval mode.

    Args:
        embedding_net: Module mapping a batch of observations to summaries.
            It is NOT moved; it must already live on ``device``.
        X_flat: Input batch; moved to ``device`` before the forward pass.
        device: Device on which the forward pass runs.

    Returns:
        The network output, computed under ``torch.no_grad()`` and moved to CPU.
    """
    # Snapshot per-module modes so mixed train/eval setups survive intact.
    saved_modes = [(module, module.training) for module in embedding_net.modules()]
    embedding_net.eval()
    try:
        with torch.no_grad():
            S = embedding_net(X_flat.to(device))
    finally:
        # modules() is parent-first, so a parent's recursive train() is
        # overridden by each child's own saved mode afterwards.
        for module, mode in saved_modes:
            module.train(mode)
    return S.cpu()