import torch
import numpy as np


from .mixture_utils import (
    scale_nll,
    get_log_prob,
    flow_zuko_gen,
    flow_flowtorch_gen,
)
from .flow_experts import FlowMixtureExperts
from .neural_rules import GMMRemixer
from .training_config import TrainingConfig

from .mixture_model_flows import FlowMixtureModel
from .mixture_model_remix import RemixMixtureModel

# Names exported by `from <this module> import *`: the NLL evaluation function
# defined in this file plus utilities, model classes and the training config
# re-exported from sibling modules. `FlowMixtureExperts` and `GMMRemixer` are
# imported for internal use only and are not part of the exported set.
__all__ = [
    "compute_test_nll_comprehensive",
    "scale_nll",
    "get_log_prob",
    "flow_zuko_gen",
    "flow_flowtorch_gen",
    "FlowMixtureModel",
    "RemixMixtureModel",
    "TrainingConfig",
]


def compute_test_nll_comprehensive(
    X_test: torch.Tensor,
    Y_test: torch.Tensor,
    mixture_rules,
    model_components,
    model_type: str = "auto",
    gmm_model=None,
    batch_size: int = -1,
    device: torch.device = torch.device("cpu"),
) -> float:
    """Return the average test negative log-likelihood of a mixture model.

    The test set is processed in batches under ``torch.no_grad()``. Each
    batch's NLL is weighted by the fraction of samples it contains, so a
    shorter final batch does not skew the result.

    Args:
        X_test: Inputs fed to ``mixture_rules``; the first dimension indexes
            samples.
        Y_test: Targets aligned with ``X_test`` along the first dimension.
        mixture_rules: Callable returning a pair whose first element is the
            per-sample rule probabilities. In ``"standard"`` mode it must
            also provide ``get_disabled_rules()``.
        model_components: In ``"standard"`` mode, whatever
            ``FlowMixtureExperts`` accepts. In ``"gmm_remix"`` mode, a
            callable taking ``(rule_probs, component_densities)`` and
            returning a 3-tuple whose first element is the batch NLL tensor.
        model_type: ``"standard"``, ``"gmm_remix"``, or ``"auto"``. ``"auto"``
            selects ``"gmm_remix"`` when ``model_components`` is a
            ``GMMRemixer`` and ``"standard"`` otherwise.
        gmm_model: Required in ``"gmm_remix"`` mode; must expose
            ``_estimate_log_prob`` (presumably a fitted scikit-learn
            ``GaussianMixture`` -- TODO confirm against callers).
        batch_size: Samples per batch; ``-1`` evaluates everything in a
            single batch.
        device: Device the data and expert wrapper are moved to.

    Returns:
        The sample-weighted mean NLL as a Python float.

    Raises:
        ValueError: If the test set is empty, ``X_test`` and ``Y_test``
            disagree on sample count, ``batch_size`` is not positive (and not
            ``-1``), ``model_type`` is unknown, or ``gmm_model`` is missing
            in ``"gmm_remix"`` mode.
    """
    n_samples = X_test.shape[0]
    if n_samples == 0:
        raise ValueError("X_test is empty; cannot compute a test NLL.")
    if Y_test.shape[0] != n_samples:
        raise ValueError(
            f"X_test has {n_samples} samples but Y_test has {Y_test.shape[0]}."
        )

    if batch_size == -1:
        batch_size = n_samples
    elif batch_size <= 0:
        raise ValueError(
            f"batch_size must be a positive integer or -1, got {batch_size}."
        )

    if model_type == "auto":
        model_type = (
            "gmm_remix" if isinstance(model_components, GMMRemixer) else "standard"
        )
    if model_type not in ("standard", "gmm_remix"):
        raise ValueError(
            f"Unknown model_type {model_type!r}; "
            "expected 'auto', 'standard' or 'gmm_remix'."
        )
    if model_type == "gmm_remix" and gmm_model is None:
        raise ValueError("gmm_model is required when model_type is 'gmm_remix'.")

    X_test, Y_test = X_test.to(device), Y_test.to(device)
    mean_nll = 0.0
    with torch.no_grad():
        experts = None
        if model_type == "standard":
            # The expert wrapper and its disabled-rule mask do not depend on
            # the batch, so build them once instead of on every iteration.
            experts = FlowMixtureExperts(model_components)
            experts.disabled_mask = torch.tensor(
                mixture_rules.get_disabled_rules(), dtype=torch.bool, device=device
            )
            experts.to(device)

        for start in range(0, n_samples, batch_size):
            X_batch = X_test[start : start + batch_size]
            Y_batch = Y_test[start : start + batch_size]
            rule_probs, _ = mixture_rules(X_batch)

            if experts is not None:
                component_log_probs = experts(Y_batch)
                # Map NaN log-probs to a very negative value so they
                # contribute (almost) nothing to the logsumexp below.
                component_log_probs = torch.nan_to_num(component_log_probs, nan=-1e8)
                # Small epsilon keeps log() finite for zero-probability rules.
                log_rule_probs = torch.log(rule_probs + 1e-8)
                mixture_log_likelihood = torch.logsumexp(
                    component_log_probs + log_rule_probs, dim=1
                )
                nll = -torch.mean(mixture_log_likelihood)
            else:
                Y_batch_np = Y_batch.cpu().numpy()
                if Y_batch_np.ndim == 1:
                    Y_batch_np = Y_batch_np.reshape(-1, 1)
                densities = np.exp(gmm_model._estimate_log_prob(Y_batch_np))
                component_densities = torch.tensor(
                    densities, dtype=torch.float32, device=device
                )
                # NOTE(review): weighting below assumes the remixer returns a
                # per-batch *mean* NLL -- confirm against GMMRemixer.
                nll, _, _ = model_components(rule_probs, component_densities)

            # Weight by the batch's share of the data; a single full-size
            # batch gets weight exactly 1.0.
            mean_nll += nll.item() * (X_batch.shape[0] / n_samples)
    return mean_nll
