import torch
import numpy as np
from typing import List, Optional

from .flow_experts import FlowMixtureExperts
from .training_config import TrainingConfig
from .mixture_model_base import MixtureModel
from .mixture_utils import (
    flow_zuko_gen,
    create_flow,
    _flow_data_cleaning,
    pretrain_flows,
)


class FlowMixtureModel(MixtureModel):
    """
    Mixture model with Normalizing Flow experts, refactored into a class-based structure.
    """

    def __init__(self, config: TrainingConfig):
        """
        Construct the model from ``config`` via the base-class constructor.

        Both optimizers start out as ``None``; ``_initialize_optimizers``
        creates them once the rules and expert models exist.
        """
        super().__init__(config)
        self.expert_optimizer: Optional[torch.optim.Optimizer] = None
        self.rules_optimizer: Optional[torch.optim.Optimizer] = None

    def _initialize_experts(self):
        """
        Build the flow experts and store them in ``self.expert_model``.

        One flow is created per mixture component. When
        ``config.use_background_component`` is set, one extra background flow
        is created and, if ``config.background_pretrain_epochs > 0``,
        pre-trained on all rows of the cleaned scaled targets before the
        flows are wrapped in ``FlowMixtureExperts`` and moved to the device.
        """
        cfg = self.config
        n_targets = self.Y_scaled.shape[1]
        flow_name, flow_kwargs = cfg.flow_gen
        make_flow = flow_zuko_gen(flow_name, features=n_targets, **flow_kwargs)

        # Standard normal base distribution with the feature axis as event dim.
        loc = torch.zeros(n_targets, device=self.device)
        scale = torch.ones(n_targets, device=self.device)
        base_dist = torch.distributions.Independent(
            torch.distributions.Normal(loc, scale), 1
        )

        # Only the first element returned by create_flow (the flow) is kept here;
        # the experts get a shared optimizer later in _initialize_optimizers.
        component_flows = []
        for _ in range(cfg.n_mixture_components):
            created = create_flow(make_flow, base_dist, self.device, cfg.lr_flow)
            component_flows.append(created[0])

        background_flow = None
        if cfg.use_background_component:
            if cfg.verbose:
                print("\n--- Initializing Background Component ---")
            background_flow, bg_optimizer = create_flow(
                make_flow, base_dist, self.device, cfg.lr_flow
            )

            if cfg.background_pretrain_epochs > 0:
                # The flow must see the same cleaned targets it will see in training.
                cleaning_mode = flow_kwargs.get("cleaning_mode", "clamp")
                cleaned_targets = _flow_data_cleaning(
                    self.Y_scaled, flow_name, mode=cleaning_mode
                )
                all_rows = torch.arange(cleaned_targets.shape[0], device=self.device)

                # A non-positive batchsize means "use the full data set".
                if cfg.batchsize > 0:
                    pretrain_batch = cfg.batchsize
                else:
                    pretrain_batch = self.X_scaled.shape[0]

                pretrain_flows(
                    [background_flow],
                    [bg_optimizer],
                    self.X_scaled,
                    cleaned_targets,
                    [all_rows],
                    n_epochs=cfg.background_pretrain_epochs,
                    batch_size=pretrain_batch,
                    device=self.device,
                )
                if cfg.verbose:
                    print("Background component pre-training complete.\n")

        experts = FlowMixtureExperts(component_flows, background_flow)
        self.expert_model = experts.to(self.device)

    def _initialize_optimizers(self):
        """
        Create one Adam optimizer for the rules model and one for the experts.

        Learning rates come from ``config.lr_rules`` and ``config.lr_flow``.

        Raises:
            RuntimeError: if either model has not been created yet.
        """
        models_ready = self.rules_model is not None and self.expert_model is not None
        if not models_ready:
            raise RuntimeError("Models must be initialized before optimizers.")

        rule_params = self.rules_model.parameters()
        expert_params = self.expert_model.parameters()
        self.rules_optimizer = torch.optim.Adam(rule_params, lr=self.config.lr_rules)
        self.expert_optimizer = torch.optim.Adam(expert_params, lr=self.config.lr_flow)

    def _reorder_experts(self, sort_idx: torch.Tensor):
        """
        Permute the component flows and ``disabled_mask`` by ``sort_idx``.

        Position ``k`` of the reordered experts holds what was previously at
        position ``sort_idx[k]``. If the expert model exposes no
        ``component_flows`` ``ModuleList`` nothing is changed (a warning is
        printed in verbose mode).
        """
        verbose = self.config.verbose
        current_flows = getattr(self.expert_model, "component_flows", None)

        # Guard: without a ModuleList of flows there is nothing to permute.
        if not isinstance(current_flows, torch.nn.ModuleList):
            if verbose:
                print(
                    "Warning: Expert model does not have a 'component_flows' ModuleList. Cannot reorder experts."
                )
            return

        if verbose:
            print("Reordering flow experts to match sorted rules.")

        self.expert_model.component_flows = torch.nn.ModuleList(
            [current_flows[i] for i in sort_idx]
        )

        # Keep the disabled flags aligned with the new flow order.
        with torch.no_grad():
            mask = self.expert_model.disabled_mask
            mask.data = mask[sort_idx]

    def get_expert_densities(self, y_range: np.ndarray) -> np.ndarray:
        """
        Evaluate every expert's density at the given target values.

        Args:
            y_range: target values in the original (unscaled) space, in the
                layout accepted by ``preprocessor.scaler_y.transform``.

        Returns:
            NumPy array of ``exp`` of the expert model's output for each input
            row, rescaled by the scaler's ``scale_`` when it has one.

        Raises:
            RuntimeError: if the expert model or the target scaler is missing.
        """
        if not (self.expert_model and self.preprocessor.scaler_y):
            raise RuntimeError("Model must be trained before computing densities.")

        experts = self.expert_model
        y_scaler = self.preprocessor.scaler_y

        experts.eval()
        experts.to(self.device)

        scaled_values = y_scaler.transform(y_range)
        y_scaled = torch.tensor(scaled_values, dtype=torch.float32, device=self.device)

        flow_name, flow_kwargs = self.config.flow_gen
        cleaning_mode = flow_kwargs.get("cleaning_mode", "clamp")
        y_clean = _flow_data_cleaning(y_scaled, flow_name, mode=cleaning_mode)

        with torch.no_grad():
            # The expert model returns one log-probability per component.
            densities = torch.exp(experts(y_clean))

        # Change of variables back to the original space:
        # p_orig(y) = p_scaled(y_s) * |dy_s/dy|.
        # NOTE(review): dividing by prod(scale_) presumes y_s = (y - offset) / scale_
        # (StandardScaler-style) -- confirm against the scaler actually used.
        scale = getattr(y_scaler, "scale_", None)
        if scale is not None:
            densities = densities / np.prod(scale)

        return densities.cpu().numpy()

    def _fit_single_model(
        self, X: np.ndarray, Y: np.ndarray, feature_names: Optional[List[str]] = None
    ):
        """
        Train one model end to end.

        Sequence: set up data and models, run the alternating rules/flows
        loop for ``config.component_train_epochs`` steps, prune and merge
        components (with optional settling epochs), optionally sort the
        components, then compute the final NLL and metrics.
        """
        cfg = self.config

        self._setup_training(X, Y, feature_names)
        self._initialize_experts()
        self._initialize_optimizers()

        def sync_disabled_experts():
            # Mirror the model-level disabled flags into the expert model.
            disabled = [k for k, off in enumerate(self.disabled_components) if off]
            self.expert_model.disable_components(disabled)

        if cfg.verbose:
            print("\n--- Starting Main Training Loop ---")

        n_samples = self.X_scaled.shape[0]
        # A non-positive batchsize means full-batch training.
        batch_size = cfg.batchsize if cfg.batchsize > 0 else n_samples

        for step in range(cfg.component_train_epochs):
            self._anneal_parameters(step)
            self._check_and_update_component_status(step)
            sync_disabled_experts()

            # Even-numbered phases update the rules, odd-numbered ones the flows.
            # NOTE(review): each phase lasts rules_steps + flow_steps steps for BOTH
            # rules and flows. If the intent was rules_steps of rule updates followed
            # by flow_steps of flow updates, this should be
            # `step % phase_len < rules_steps` -- confirm.
            phase_len = cfg.rules_steps + cfg.flow_steps
            in_rules_phase = (step // phase_len) % 2 == 0
            train_step = (
                self._train_rules_step if in_rules_phase else self._train_expert_step
            )

            shuffled = torch.randperm(n_samples, device=self.device)
            for start in range(0, n_samples, batch_size):
                rows = shuffled[start : start + batch_size]
                train_step(self.X_scaled[rows], self.Y_scaled[rows])

            self._log_and_record(step)

        self._log_and_record(cfg.component_train_epochs)
        if cfg.verbose:
            print("\n--- Main Training Loop Finished ---")

        # Structural clean-up: each stage may disable further components.
        pruned = self._prune_components()
        if pruned:
            sync_disabled_experts()

        merged = self._merge_components()
        if merged:
            sync_disabled_experts()

        if (merged or pruned) and cfg.merge_settle_epochs > 0:
            self._settle_components()

        # Sorting is on unless the config explicitly turns it off.
        if getattr(cfg, "sort_rules_by_mean", True):
            self.sort_components_by_mean()

        self.rules_model.eval()
        self.expert_model.eval()
        with torch.no_grad():
            train_log_lik = self._get_log_likelihood(self.X_scaled, self.Y_scaled)
            final_nll_raw = -torch.mean(train_log_lik).item()

        # "raw" is the NLL in the scaled space; get_nll is fed the original data.
        final_nll_scaled = self.get_nll(self.X_original, self.Y_original)
        self.metrics["raw_nll"] = final_nll_raw
        self._calculate_final_metrics(final_nll_scaled)

    def _get_log_likelihood(
        self, X_tensor: torch.Tensor, Y_tensor: torch.Tensor
    ) -> torch.Tensor:
        """
        Return the per-sample mixture log-likelihood log p(y|x) in scaled space.

        The targets are cleaned for the configured flow, then the mixture is
        evaluated as ``logsumexp`` over components of
        ``log(rule_prob) + component_log_prob``. No rescaling to the original
        data space is applied here.
        """
        flow_name, flow_kwargs = self.config.flow_gen
        cleaning_mode = flow_kwargs.get("cleaning_mode", "clamp")
        y_clean = _flow_data_cleaning(Y_tensor, flow_name, mode=cleaning_mode)

        gate_probs, _ = self.rules_model(X_tensor)
        expert_log_probs = self.expert_model(y_clean)

        # The small epsilon keeps log() finite when a mixing weight is exactly 0.
        log_gate = torch.log(gate_probs + 1e-10)

        if torch.isnan(expert_log_probs).any():
            print("NaN detected in component log probabilities, fixing.")
            # NaNs become a very low log-prob so those entries contribute ~0 mass.
            # With default arguments nan_to_num also maps +/-inf to the largest
            # finite values of the dtype.
            expert_log_probs = torch.nan_to_num(expert_log_probs, nan=-1e10)

        return torch.logsumexp(expert_log_probs + log_gate, dim=1)

    def _train_rules_step(self, X_batch: torch.Tensor, Y_batch: torch.Tensor) -> None:
        """
        Perform one optimizer step on the rules model for a single batch.

        The batch loss comes from ``_calculate_loss``. If the loss is not
        finite the batch is skipped and no parameter is changed.
        """
        self.rules_optimizer.zero_grad()
        loss = self._calculate_loss(X_batch, Y_batch)
        # Skip on NaN *and* on +/-inf: back-propagating a non-finite loss can
        # yield non-finite gradients, and stepping on those would poison the
        # parameters and Adam's running moments for the rest of training.
        if not torch.isfinite(loss):
            return
        loss.backward()
        self.rules_optimizer.step()

    def _train_expert_step(self, X_batch: torch.Tensor, Y_batch: torch.Tensor) -> None:
        """
        Perform one optimizer step on the flow experts for a single batch.

        The batch loss comes from ``_calculate_loss``. If the loss is not
        finite the batch is skipped and no parameter is changed.
        """
        self.expert_optimizer.zero_grad()
        loss = self._calculate_loss(X_batch, Y_batch)
        # Skip on NaN *and* on +/-inf: back-propagating a non-finite loss can
        # yield non-finite gradients, and stepping on those would poison the
        # parameters and Adam's running moments for the rest of training.
        if not torch.isfinite(loss):
            return
        loss.backward()
        self.expert_optimizer.step()

    def _calculate_loss(self, X_batch, Y_batch, return_dict=False):
        """
        Compute the training loss for a batch: NLL plus regularization terms.

        total = nll
                - partition_weight  * mean_i(sum_k p_ik^2)
                + coverage_weight   * sum_k relu(1 / (3 * K) - mean_i p_ik)
                + and_layer_entropy * (second output of the rules model)

        where ``p_ik`` are the rule probabilities of the first
        ``K = config.n_mixture_components`` columns. A term whose weight is
        not positive is left at zero.

        Args:
            X_batch: scaled inputs for the batch.
            Y_batch: scaled targets for the batch.
            return_dict: if True, return every term separately.

        Returns:
            The scalar total loss, or a dict with keys ``"total"``, ``"nll"``,
            ``"partition"``, ``"coverage"`` and ``"rule_l1"``.
        """
        log_likelihood = self._get_log_likelihood(X_batch, Y_batch)
        nll_loss = -torch.mean(log_likelihood)

        # The rules model is evaluated a second time here to obtain its
        # regularization output alongside the rule probabilities.
        rule_probs, and_layer_l1_loss = self.rules_model(X_batch)
        interpretable_rule_probs = rule_probs[:, : self.config.n_mixture_components]

        partition_loss = torch.tensor(0.0, device=self.device)
        if self.config.partition_weight > 0:
            # Largest when each sample is assigned to a single rule.
            partition_loss = torch.mean(torch.sum(interpretable_rule_probs**2, dim=1))

        coverage_loss = torch.tensor(0.0, device=self.device)
        if self.config.coverage_weight > 0:
            # Penalize components whose average responsibility drops below
            # one third of the uniform share 1/K.
            mean_responsibility = torch.mean(interpretable_rule_probs, dim=0)
            threshold = 1.0 / (self.config.n_mixture_components * 3.0)
            coverage_loss = torch.sum(torch.relu(threshold - mean_responsibility))

        # Build the total out-of-place. The previous `total_loss = nll_loss`
        # followed by `-=` / `+=` modified the nll tensor itself, so the "nll"
        # entry of the returned dict always equalled the total loss.
        total_loss = (
            nll_loss
            - self.config.partition_weight * partition_loss
            + self.config.coverage_weight * coverage_loss
            + self.config.and_layer_entropy * and_layer_l1_loss
        )

        if return_dict:
            return {
                "total": total_loss,
                "nll": nll_loss,
                "partition": partition_loss,
                "coverage": coverage_loss,
                "rule_l1": and_layer_l1_loss,
            }

        return total_loss

    def _merge_components(self) -> bool:
        """
        Ask the rules model to merge adjacent, similar components.

        Returns:
            ``False`` when merging is disabled in the config or there is no
            rules model; otherwise whatever
            ``rules_model.merge_adjacent_components`` returns.
        """
        cfg = self.config
        if not (cfg.merge_components and self.rules_model):
            return False

        if cfg.verbose:
            print("\n--- Attempting to Merge Similar Components ---")

        merge_args = {
            "X": self.X_scaled,
            "Y": self.Y_scaled,
            "density_model": self.expert_model,
            "density_model_type": "flow",
            "iou_threshold": cfg.merge_iou_threshold,
            "adjacency_tol": cfg.merge_adjacency_tol,
            "density_jsd_threshold": cfg.merge_jsd_threshold,
            "verbose": cfg.verbose,
            "disabled_components": self.disabled_components,
        }
        merged = self.rules_model.merge_adjacent_components(**merge_args)

        if cfg.verbose:
            print(f"Merging complete. Merged status: {merged}")
        return merged

    def _settle_components(self):
        """
        Run ``config.merge_settle_epochs`` extra epochs after pruning/merging.

        Even epochs update the flows and odd epochs update the rules. Progress
        is logged with step numbers continuing after the main training loop.
        Nothing is done when every component is disabled.
        """
        if all(self.disabled_components):
            if self.config.verbose:
                print("Skipping settling: all components are disabled.")
            return

        if self.config.verbose:
            print(
                f"\n--- Settling Components for {self.config.merge_settle_epochs} epochs ---"
            )

        n_samples = self.X_scaled.shape[0]
        # A non-positive batchsize means full-batch training.
        if self.config.batchsize > 0:
            batch_size = self.config.batchsize
        else:
            batch_size = n_samples

        for epoch in range(self.config.merge_settle_epochs):
            flows_turn = epoch % 2 == 0
            train_step = (
                self._train_expert_step if flows_turn else self._train_rules_step
            )

            shuffled = torch.randperm(n_samples, device=self.device)
            for start in range(0, n_samples, batch_size):
                rows = shuffled[start : start + batch_size]
                train_step(self.X_scaled[rows], self.Y_scaled[rows])

            # Step numbering continues where the main loop left off.
            self._log_and_record(self.config.component_train_epochs + epoch)

        if self.config.verbose:
            print("--- Settling complete ---")

    def _log_and_record(self, step: int):
        """
        Evaluate the full-data loss and hand it to ``_log_progress`` when due.

        A step is due when verbose mode is on and ``step`` is a positive
        multiple of roughly a tenth of ``component_train_epochs``, or when
        ``record_history_every`` is positive and divides ``step``.
        """
        cfg = self.config

        print_interval = max(1, cfg.component_train_epochs // 10)
        due_for_print = cfg.verbose and step > 0 and step % print_interval == 0
        due_for_history = (
            cfg.record_history_every > 0 and step % cfg.record_history_every == 0
        )
        if not (due_for_print or due_for_history):
            return

        # Pure evaluation on the whole training set: no gradients needed.
        with torch.no_grad():
            loss_terms = self._calculate_loss(
                self.X_scaled, self.Y_scaled, return_dict=True
            )
            self._log_progress(step, loss_terms)

    def _calculate_final_metrics(self, final_nll: float):
        """
        Store the final NLL, AIC, BIC and parameter counts in ``self.metrics``.

        AIC and BIC are computed from the rule parameters only; the flow
        parameter count is recorded under ``"n_density_params"`` but is not
        part of ``"total_parameters"``. A NaN ``final_nll`` yields NaN AIC/BIC.

        Args:
            final_nll: mean negative log-likelihood per sample.
        """
        if not (self.rules_model and self.expert_model):
            print("Warning: Cannot calculate metrics. Models not initialized.")
            return

        def count_trainable(module):
            return sum(p.numel() for p in module.parameters() if p.requires_grad)

        n_rule_params = self.rules_model.count_active_parameters()

        # Flows belonging to disabled components are not counted.
        n_density_params = sum(
            count_trainable(flow)
            for k, flow in enumerate(self.expert_model.component_flows)
            if not self.disabled_components[k]
        )
        if self.config.use_background_component and self.expert_model.background_flow:
            n_density_params += count_trainable(self.expert_model.background_flow)

        # Only the rule parameters enter the information criteria.
        total_parameters = n_rule_params
        n_samples = self.X_scaled.shape[0]

        if np.isnan(final_nll):
            aic = bic = np.nan
        else:
            # final_nll is a per-sample mean, so 2 * n * nll is -2 log L.
            neg2_log_lik = 2 * n_samples * final_nll
            aic = 2 * total_parameters + neg2_log_lik
            bic = total_parameters * np.log(n_samples) + neg2_log_lik

        self.metrics.update(
            {
                "final_nll": final_nll,
                "aic": aic,
                "bic": bic,
                "total_parameters": total_parameters,
                "n_rule_params": n_rule_params,
                "n_density_params": n_density_params,
            }
        )

        if not self.config.verbose:
            return

        print("\n--- Final Metrics ---")
        print(f"Final Scaled NLL: {final_nll:.4f}")
        print(
            f"Total Active Parameters: {total_parameters} (Rules: {n_rule_params}, Experts: {n_density_params})"
        )
        print(f"AIC: {aic:.2f}")
        print(f"BIC: {bic:.2f}")
        print("---------------------\n")

    def save(self, path: str):
        """
        Write the trained model to ``path`` with ``torch.save``.

        The file holds the config, both state dicts, the preprocessor,
        metrics, history, feature names and the disabled-component flags.
        Optimizer state is not saved.

        Raises:
            RuntimeError: if the rules or expert model is missing.
        """
        if not (self.rules_model and self.expert_model):
            raise RuntimeError("Cannot save an untrained model.")

        checkpoint = dict(
            config=self.config,
            rules_model_state_dict=self.rules_model.state_dict(),
            expert_model_state_dict=self.expert_model.state_dict(),
            preprocessor=self.preprocessor,
            metrics=self.metrics,
            history=self.history,
            feature_names=self.feature_names,
            disabled_components=self.disabled_components,
        )
        torch.save(checkpoint, path)

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "MixtureModel":
        """
        Load a model written by ``save`` and return it in eval mode.

        Args:
            path: file produced by ``save``.
            device: device string used both as ``map_location`` and as the
                model's device.

        Returns:
            A model instance with config, preprocessor, weights, metrics,
            history and disabled-component flags restored. Optimizers are
            not created.

        Note:
            This reseeds the global torch and NumPy RNGs with ``config.seed``.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects (the
        # checkpoint stores the config and preprocessor as objects). Only load
        # files from a trusted source.
        state = torch.load(path, map_location=device, weights_only=False)
        config = state["config"]

        # Global side effect: both RNGs are reseeded before the model is rebuilt.
        torch.manual_seed(config.seed)
        np.random.seed(config.seed)

        model = cls(config)
        model.preprocessor = state["preprocessor"]
        model.feature_names = state["feature_names"]
        model.device = torch.device(device)
        # Older checkpoints may lack these keys, hence the defaults.
        model.metrics = state.get("metrics", {})
        model.history = state.get("history", [])
        model.disabled_components = state.get(
            "disabled_components", [False] * config.n_mixture_components
        )

        # Two-row zero placeholders: _initialize_experts reads Y_scaled.shape[1] to
        # size the flows; presumably _setup_rules_model needs X_scaled's width in
        # the same way -- verify against the base class.
        dummy_x_shape = len(model.feature_names)
        dummy_y_shape = model.preprocessor.scaler_y.n_features_in_
        model.X_scaled = torch.zeros(
            (2, dummy_x_shape), device=model.device, dtype=torch.float32
        )
        model.Y_scaled = torch.zeros(
            (2, dummy_y_shape), device=model.device, dtype=torch.float32
        )

        # Rebuild the architectures so the saved state dicts have matching keys.
        # NOTE(review): when the config enables a background component with
        # background_pretrain_epochs > 0, _initialize_experts pre-trains it on the
        # zero placeholders above; those weights are overwritten by
        # load_state_dict just below, so that work looks wasted -- consider
        # skipping pre-training on load.
        model._setup_rules_model()
        model._initialize_experts()

        model.rules_model.load_state_dict(state["rules_model_state_dict"])
        model.expert_model.load_state_dict(state["expert_model_state_dict"])

        # Re-apply the disabled flags to both the rules and the expert model.
        if model.rules_model:
            for i, is_disabled in enumerate(model.disabled_components):
                # Guard against a flag list longer than the rebuilt rule list.
                if i < len(model.rules_model.rules):
                    model.rules_model.rules[i].disabled = is_disabled
        if model.expert_model:
            model.expert_model.disable_components(
                [i for i, d in enumerate(model.disabled_components) if d]
            )

        # The returned model is ready for inference on the requested device.
        model.rules_model.to(model.device)
        model.expert_model.to(model.device)
        model.rules_model.eval()
        model.expert_model.eval()

        return model
