import torch
import numpy as np
from typing import List, Optional

from sklearn.mixture import GaussianMixture


from .neural_rules import GMMRemixer
from .training_config import TrainingConfig
from .mixture_model_base import MixtureModel


class RemixMixtureModel(MixtureModel):
    """Mixture model with a GMM Remixer expert."""

    def __init__(self, config: TrainingConfig):
        """Initialise the base mixture model and declare remix-specific state.

        Args:
            config: Training configuration, forwarded unchanged to the
                ``MixtureModel`` constructor.
        """
        super().__init__(config)
        # Placeholders that are filled in while fitting; ``None`` means
        # "not created yet".
        self.gmm_component_densities_full: Optional[torch.Tensor] = None
        self.gmm_model: Optional[GaussianMixture] = None
        self.optimizer: Optional[torch.optim.Optimizer] = None

    def _initialize_experts(self):
        """Build the ``GMMRemixer`` expert and place it on ``self.device``.

        The number of GMM components is read from the fitted global GMM when
        one is available (it may differ from the configured value after
        component selection); otherwise ``config.n_gmm_components`` is used.
        """
        cfg = self.config
        if self.gmm_model:
            component_count = self.gmm_model.n_components
        else:
            component_count = cfg.n_gmm_components

        remixer = GMMRemixer(
            n_rules=cfg.n_mixture_components,
            n_gmm_components=component_count,
            use_background_component=cfg.use_background_component,
            diagonal=cfg.diagonal_gmm_init,
        )
        self.expert_model = remixer.to(self.device)

    def _initialize_optimizers(self):
        """Create a single Adam optimizer covering the rules and the expert.

        Each model gets its own parameter group so that ``config.lr_rules``
        and ``config.lr_gmm_remix`` are applied separately.
        """
        rules_group = {
            "params": self.rules_model.parameters(),
            "lr": self.config.lr_rules,
        }
        expert_group = {
            "params": self.expert_model.parameters(),
            "lr": self.config.lr_gmm_remix,
        }
        self.optimizer = torch.optim.Adam([rules_group, expert_group])

    def _reorder_experts(self, sort_idx: torch.Tensor):
        """Apply a rule permutation to the expert via ``reorder_rules``.

        Args:
            sort_idx: Index tensor passed unchanged to
                ``expert_model.reorder_rules``.

        If the expert has no ``reorder_rules`` attribute nothing is reordered;
        in verbose mode a warning is printed instead.
        """
        verbose = self.config.verbose
        if hasattr(self.expert_model, "reorder_rules"):
            if verbose:
                print("Reordering GMM Remixer weights to match sorted rules.")
            self.expert_model.reorder_rules(sort_idx)
        elif verbose:
            print(
                "Warning: Expert model does not have a 'reorder_rules' method. Cannot reorder experts."
            )

    def get_expert_densities(self, y_range: np.ndarray) -> np.ndarray:
        """
        Computes the probability density of each remixed GMM expert over a
        given range of y values.

        The values are scaled with the fitted target scaler, the base GMM
        component densities are evaluated in that scaled space and combined
        with the expert's mixing weights, and the result is rescaled so that
        it is a density over the original target space.

        Args:
            y_range: Target values in the original (unscaled) space. A 1-D
                array is interpreted as a column of single-feature samples,
                the same way 1-D targets are handled elsewhere in this class.

        Returns:
            NumPy array holding the base component densities matrix-multiplied
            by the expert's mixing weights, after rescaling.

        Raises:
            RuntimeError: If the expert model, the global GMM or the target
                scaler is missing.
        """
        if (
            not self.expert_model
            or not self.gmm_model
            or not self.preprocessor.scaler_y
        ):
            raise RuntimeError("Model must be trained before computing densities.")

        # The scaler's ``transform`` only accepts 2-D input; promote a flat
        # vector to a single column. 2-D input is passed through untouched.
        if np.ndim(y_range) == 1:
            y_range = np.asarray(y_range).reshape(-1, 1)

        # Note: the expert is switched to eval mode and left that way.
        self.expert_model.eval()
        self.expert_model.to(self.device)

        # 1. Scale the input y_range
        scaler = self.preprocessor.scaler_y
        y_range_scaled_np = scaler.transform(y_range)

        # 2. Get base GMM component densities on the scaled range
        base_log_probs_np = self.gmm_model._estimate_log_prob(y_range_scaled_np)
        base_densities = torch.from_numpy(np.exp(base_log_probs_np)).to(
            self.device, dtype=torch.float32
        )

        # 3. Get the remixing weights from the expert model
        with torch.no_grad():
            # get_mixing_weights() returns softmax-normalized weights
            remix_weights = self.expert_model.get_mixing_weights()

        # 4. Compute the remixed densities in the scaled space
        subgroup_densities_scaled = base_densities @ remix_weights

        # 5. Rescale density for original data space:
        #    p_orig(y) = p_scaled(y_s) * |dy_s/dy|
        # NOTE(review): dividing by prod(scale_) is the right Jacobian when
        # scale_ is a divisor (StandardScaler-style); confirm the scaler type.
        if hasattr(scaler, "scale_") and scaler.scale_ is not None:
            # For multi-dimensional data, the scaling factor is the product
            # of the scales along each dimension.
            scale_product = np.prod(scaler.scale_)
            densities_orig = subgroup_densities_scaled / scale_product
        else:
            densities_orig = subgroup_densities_scaled

        return densities_orig.cpu().numpy()

    def _fit_single_model(
        self, X: np.ndarray, Y: np.ndarray, feature_names: Optional[List[str]] = None
    ):
        """
        Fits one GMM-Remix model end to end.

        Steps, in order: training setup, global GMM fit, expert and optimizer
        construction, caching of the per-sample GMM component densities,
        optional remixer-only pre-training, the joint training loop,
        pruning / merging (followed by settling when something changed),
        optional sorting of the components, and the final metric calculation.

        Args:
            X: Input features, forwarded to ``_setup_training``.
            Y: Targets, forwarded to ``_setup_training``.
            feature_names: Optional feature names, forwarded to
                ``_setup_training``.
        """
        self._setup_training(X, Y, feature_names)
        self._fit_global_gmm()
        self._initialize_experts()
        self._initialize_optimizers()

        # Evaluate the GMM component densities for every training target once
        # and keep them on the device; every batch below indexes this tensor.
        with torch.no_grad():
            y_np = self.Y_scaled.cpu().numpy()
            if y_np.ndim == 1:
                y_np = y_np.reshape(-1, 1)
            densities_np = np.exp(self.gmm_model._estimate_log_prob(y_np))
            self.gmm_component_densities_full = torch.tensor(
                densities_np, dtype=torch.float32, device=self.device
            )

        if self.config.remix_pretrain_epochs > 0:
            if self.config.verbose:
                print("\n--- Pre-training GMM Remixer ---")

            # Freeze the rules for the pre-training phase. The previous flags
            # are remembered so parameters that were already frozen stay
            # frozen afterwards, and the ``finally`` guarantees the rules are
            # not left frozen if pre-training raises.
            rule_params = list(self.rules_model.parameters())
            saved_grad_flags = [param.requires_grad for param in rule_params]
            for param in rule_params:
                param.requires_grad = False

            try:
                for epoch in range(self.config.remix_pretrain_epochs):
                    self.optimizer.zero_grad()
                    with torch.no_grad():
                        rule_probs, _ = self.rules_model(self.X_scaled)

                    log_likelihoods, l1_penalty = self.expert_model(
                        rule_probs,
                        self.gmm_component_densities_full,
                        mean_reduce=False,
                    )
                    nll_loss = -torch.mean(log_likelihoods)
                    pretrain_loss = (
                        nll_loss + self.config.gmm_remix_l1_weight * l1_penalty
                    )
                    pretrain_loss.backward()
                    self.optimizer.step()
                    if self.config.verbose and (epoch + 1) % 10 == 0:
                        print(
                            f"Pre-train Epoch {epoch+1}, Loss: {pretrain_loss.item():.4f}"
                        )
            finally:
                for param, flag in zip(rule_params, saved_grad_flags):
                    param.requires_grad = flag

        if self.config.verbose:
            print("\n--- Starting Main Training Loop ---")
        # A non-positive batch size means full-batch training.
        batch_size = (
            self.config.batchsize
            if self.config.batchsize > 0
            else self.X_scaled.shape[0]
        )
        self.rules_model.train()
        self.expert_model.train()

        for step in range(self.config.component_train_epochs):
            self._anneal_parameters(step)
            self._check_and_update_component_status(step)

            # Fresh shuffle every epoch; the same indices select both the
            # inputs and the matching cached densities.
            idx = torch.randperm(self.X_scaled.shape[0], device=self.device)
            for i in range(0, self.X_scaled.shape[0], batch_size):
                idx_batch = idx[i : i + batch_size]
                X_batch = self.X_scaled[idx_batch]
                densities_batch = self.gmm_component_densities_full[idx_batch]
                self._train_step(X_batch, densities_batch)

            self._log_and_record(step)

        # One more logging call for the step index equal to the epoch count.
        self._log_and_record(self.config.component_train_epochs)
        if self.config.verbose:
            print("\n--- Main Training Loop Finished ---")

        pruned = self._prune_components()
        merged = self._merge_components()

        if (merged or pruned) and self.config.merge_settle_epochs > 0:
            self._settle_components()

        # Sorting is on by default when the config lacks the attribute.
        if getattr(self.config, "sort_rules_by_mean", True):
            self.sort_components_by_mean()

        # --- Final Calculations ---
        self.rules_model.eval()
        self.expert_model.eval()
        with torch.no_grad():
            final_log_likelihood = self._get_log_likelihood(
                self.X_scaled, self.Y_scaled
            )
            final_nll_raw = -torch.mean(final_log_likelihood).item()

        # NOTE(review): "raw_nll" is computed on the scaled tensors, while
        # ``final_nll_scaled`` comes from ``get_nll`` on the original arrays;
        # the two names look swapped. Confirm before renaming.
        final_nll_scaled = self.get_nll(self.X_original, self.Y_original)
        self.metrics["raw_nll"] = final_nll_raw
        self._calculate_final_metrics(final_nll_scaled)

    def _log_and_record(self, step: int):
        """Compute full-data losses and hand them to ``_log_progress`` when due.

        A step is due when verbose mode is on, ``step`` is positive and falls
        on roughly every tenth of ``component_train_epochs``, or when
        ``record_history_every`` is positive and divides ``step``.

        Args:
            step: Step index used for the schedule and passed to
                ``_log_progress``.
        """
        cfg = self.config
        is_due = (
            cfg.verbose
            and step > 0
            and step % (max(1, cfg.component_train_epochs // 10)) == 0
        )
        if not is_due:
            # Fall back to the history schedule only when verbose logging
            # did not already trigger.
            is_due = (
                cfg.record_history_every > 0
                and step % cfg.record_history_every == 0
            )
        if not is_due:
            return

        # Pure evaluation on the whole training set; no gradients needed.
        with torch.no_grad():
            loss_terms = self._calculate_loss(
                self.X_scaled,
                self.gmm_component_densities_full,
                return_dict=True,
            )
            self._log_progress(step, loss_terms)

    def _merge_components(self) -> bool:
        """
        Delegates merging of similar components to
        ``rules_model.merge_adjacent_components``.

        Returns:
            ``False`` when merging is disabled in the config, when there is
            no rules model, or when the rules model has no
            ``merge_adjacent_components`` attribute; otherwise whatever the
            delegate returns.
        """
        cfg = self.config
        if not (cfg.merge_components and self.rules_model):
            return False

        verbose = cfg.verbose
        if verbose:
            print("\n--- Attempting to Merge Similar Components ---")

        if not hasattr(self.rules_model, "merge_adjacent_components"):
            if verbose:
                print(
                    "Warning: rules_model does not have 'merge_adjacent_components' method. Skipping merge."
                )
            return False

        # Thresholds come straight from the config; the expert and the global
        # GMM are passed along as the density model.
        merge_kwargs = dict(
            X=self.X_scaled,
            Y=self.Y_scaled,
            density_model=self.expert_model,
            gmm_model=self.gmm_model,
            density_model_type="gmm_remix",
            iou_threshold=cfg.merge_iou_threshold,
            adjacency_tol=cfg.merge_adjacency_tol,
            density_jsd_threshold=cfg.merge_jsd_threshold,
            verbose=verbose,
            disabled_components=self.disabled_components,
        )
        merged = self.rules_model.merge_adjacent_components(**merge_kwargs)

        if verbose:
            print(f"Merging complete. Merged status: {merged}")
        return merged

    def _settle_components(self):
        """
        Runs ``config.merge_settle_epochs`` extra training epochs so the
        parameters can stabilise after pruning or merging.

        Nothing is trained when every entry of ``disabled_components`` is
        truthy. Each epoch reshuffles the data, trains batch by batch, and
        then calls ``_log_and_record``.
        """
        cfg = self.config
        verbose = cfg.verbose

        if all(self.disabled_components):
            if verbose:
                print("Skipping settling: all components are disabled.")
            return

        if verbose:
            print(
                f"\n--- Settling Components for {self.config.merge_settle_epochs} epochs ---"
            )

        n_samples = self.X_scaled.shape[0]
        # A non-positive batch size means full-batch training.
        batch_size = cfg.batchsize if cfg.batchsize > 0 else n_samples

        self.rules_model.train()
        self.expert_model.train()

        for settle_epoch in range(cfg.merge_settle_epochs):
            shuffled = torch.randperm(n_samples, device=self.device)

            for start in range(0, n_samples, batch_size):
                batch_idx = shuffled[start : start + batch_size]
                self._train_step(
                    self.X_scaled[batch_idx],
                    self.gmm_component_densities_full[batch_idx],
                )

            # Settling steps are numbered after the main training epochs.
            self._log_and_record(cfg.component_train_epochs + settle_epoch)

        if verbose:
            print("--- Settling complete ---")

    def _calculate_final_metrics(self, final_nll: float):
        """
        Stores the final NLL, AIC/BIC and parameter counts in ``self.metrics``.

        Parameter counts are gathered for the rules, the expert and the
        global GMM, but only the rule parameters enter AIC/BIC.

        Args:
            final_nll: Mean negative log-likelihood per sample; a NaN value
                yields NaN for both AIC and BIC.
        """
        if not (self.rules_model and self.expert_model and self.gmm_model):
            print("Warning: Cannot calculate metrics. Models not initialized.")
            return

        n_rule_params = self.rules_model.count_active_parameters()

        # Scale the trainable expert parameter count by the share of
        # components that are still enabled.
        disabled_fraction = sum(self.disabled_components) / len(
            self.disabled_components
        )
        trainable_expert_params = sum(
            p.numel() for p in self.expert_model.parameters() if p.requires_grad
        )
        n_expert_params = int(trainable_expert_params * (1 - disabled_fraction))

        n_gmm_params = self.gmm_model._n_parameters()

        # only use rule parameters
        total_parameters = n_rule_params
        n_samples = self.X_scaled.shape[0]

        if np.isnan(final_nll):
            aic, bic = np.nan, np.nan
        else:
            # final_nll is a per-sample mean, so multiply by n_samples.
            deviance = 2 * n_samples * final_nll
            aic = 2 * total_parameters + deviance
            bic = total_parameters * np.log(n_samples) + deviance

        for key, value in (
            ("final_nll", final_nll),
            ("aic", aic),
            ("bic", bic),
            ("total_parameters", total_parameters),
            ("n_rule_params", n_rule_params),
            ("n_expert_params", n_expert_params),
            ("n_gmm_params", n_gmm_params),
        ):
            self.metrics[key] = value

        if not self.config.verbose:
            return

        print("\n--- Final Metrics ---")
        print(f"Final Scaled NLL: {final_nll:.4f}")
        print(
            f"Total Active Parameters: {total_parameters} (Rules: {n_rule_params}, Expert: {n_expert_params}, GMM: {n_gmm_params})"
        )
        print(f"AIC: {aic:.2f}")
        print(f"BIC: {bic:.2f}")
        print("---------------------\n")

    def _get_log_likelihood(
        self, X_tensor: torch.Tensor, Y_tensor: torch.Tensor
    ) -> torch.Tensor:
        """
        Returns the expert's per-sample log-likelihood for already-scaled data.

        The GMM component densities are evaluated for ``Y_tensor`` and passed,
        together with the rule probabilities for ``X_tensor``, to the expert
        model with ``mean_reduce=False``. No rescaling is applied to the
        result.

        Args:
            X_tensor: Scaled inputs fed to the rules model.
            Y_tensor: Scaled targets; a 1-D tensor is treated as a column.

        Raises:
            RuntimeError: If the global GMM has not been fitted.
        """
        if not self.gmm_model:
            raise RuntimeError(
                "Global GMM has not been fitted. This should not happen after training."
            )

        targets = Y_tensor.cpu().numpy()
        if targets.ndim == 1:
            targets = targets.reshape(-1, 1)

        component_log_probs = self.gmm_model._estimate_log_prob(targets)
        component_densities = torch.tensor(
            np.exp(component_log_probs), dtype=torch.float32, device=self.device
        )

        # The second value returned by each model is not needed here.
        rule_probs, _unused_rule_penalty = self.rules_model(X_tensor)
        per_sample_ll, _unused_expert_penalty = self.expert_model(
            rule_probs, component_densities, mean_reduce=False
        )
        return per_sample_ll

    def _fit_global_gmm(self):
        """Fits the global GMM on the scaled target data.

        When ``config.component_scoring`` is ``"bic"`` or ``"aic"`` (case
        insensitive), one model is fitted for every component count from 5 up
        to ``config.n_gmm_components`` and the count with the lowest score is
        kept. When scoring is ``None`` or that range is empty, the configured
        count is used as is. The final model is fitted with the chosen count
        plus ``config.n_gmm_extra_components`` and stored in
        ``self.gmm_model``.

        Raises:
            ValueError: If a component search is about to run but
                ``component_scoring`` names an unknown method.
        """
        cfg = self.config

        Y_np = self.Y_scaled.cpu().numpy()
        if Y_np.ndim == 1:
            Y_np = Y_np.reshape(-1, 1)

        if cfg.verbose:
            print(
                f"Fitting global GMM with up to {self.config.n_gmm_components} components..."
            )

        def fit_gmm(n):
            # Every fit uses the same hyper-parameters and seed.
            return GaussianMixture(
                n_components=n,
                reg_covar=cfg.gmm_reg_covar,
                max_iter=cfg.gmm_max_iter,
                n_init=3,
                random_state=cfg.seed,
            ).fit(Y_np)

        component_scoring = cfg.component_scoring
        n_components = cfg.n_gmm_components

        # The search always starts at 5 components; if fewer are configured
        # the range is empty and the configured count is kept.
        candidate_counts = range(5, cfg.n_gmm_components + 1)

        if component_scoring is not None and candidate_counts:
            # Validate once, before any (expensive) fit is attempted.
            method = component_scoring.lower()
            if method not in ("bic", "aic"):
                raise ValueError(
                    f"Unknown component scoring method: {component_scoring}"
                )

            best_score = np.inf
            for n_candidate in candidate_counts:
                trial_model = fit_gmm(n_candidate)
                if method == "bic":
                    score = trial_model.bic(Y_np)
                else:
                    score = trial_model.aic(Y_np)
                # Strict comparison: ties keep the smaller component count.
                if score < best_score:
                    best_score = score
                    n_components = n_candidate

        if cfg.verbose:
            # Reports the selected count; extra components are added below.
            print(f"Using {n_components} global GMM components.")

        self.gmm_model = fit_gmm(n_components + cfg.n_gmm_extra_components)

        if not self.gmm_model.converged_:
            print("Warning: Global GMM did not converge.")

    def _calculate_loss(self, X_batch, densities_batch, return_dict=False):
        """
        Computes the training objective for one batch.

        The objective is the mean negative log-likelihood from the expert,
        plus the weighted expert and rules penalties, minus the weighted
        partition term.

        Args:
            X_batch: Inputs fed to the rules model.
            densities_batch: GMM component densities for the same samples.
            return_dict: When true, return every term by name instead of only
                the total.

        Returns:
            The total loss tensor, or a dict with keys ``total``, ``nll``,
            ``gmm_l1``, ``rule_entropy`` and ``partition``.
        """
        cfg = self.config

        # Rules model: per-sample rule probabilities and its penalty term.
        rule_probs, rule_penalty = self.rules_model(X_batch)

        # Expert model: per-sample log-likelihood and its L1 penalty.
        per_sample_ll, remix_penalty = self.expert_model(
            rule_probs, densities_batch, mean_reduce=False
        )
        nll_loss = per_sample_ll.mean().neg()

        if cfg.partition_weight > 0:
            # Only the first n_mixture_components columns take part.
            main_probs = rule_probs[:, : cfg.n_mixture_components]
            partition_loss = main_probs.pow(2).sum(dim=1).mean()
        else:
            partition_loss = torch.tensor(0.0, device=self.device)

        remix_term = cfg.gmm_remix_l1_weight * remix_penalty
        rule_term = cfg.and_layer_entropy * rule_penalty
        partition_term = cfg.partition_weight * partition_loss
        total_loss = nll_loss + remix_term + rule_term - partition_term

        if not return_dict:
            return total_loss

        return {
            "total": total_loss,
            "nll": nll_loss,
            "gmm_l1": remix_penalty,
            "rule_entropy": rule_penalty,
            "partition": partition_loss,
        }

    def _train_step(self, X_batch, densities_batch):
        """Perform one optimizer update on a single batch.

        Clears the gradients, back-propagates the loss returned by
        ``_calculate_loss`` and steps the shared optimizer.
        """
        optimizer = self.optimizer
        optimizer.zero_grad()
        self._calculate_loss(X_batch, densities_batch).backward()
        optimizer.step()

    def save(self, path: str):
        """Write the model state to ``path`` with ``torch.save``.

        The saved dict holds the config, both state dicts, the preprocessor,
        metrics, history, feature names, the global GMM and the
        disabled-component flags.

        Args:
            path: Destination handed to ``torch.save``.

        Raises:
            RuntimeError: If the rules model or the expert model is missing.
        """
        if not (self.rules_model and self.expert_model):
            raise RuntimeError("Cannot save an untrained model.")

        checkpoint = dict(
            config=self.config,
            rules_model_state_dict=self.rules_model.state_dict(),
            expert_model_state_dict=self.expert_model.state_dict(),
            preprocessor=self.preprocessor,
            metrics=self.metrics,
            history=self.history,
            feature_names=self.feature_names,
            gmm_model=self.gmm_model,
            disabled_components=self.disabled_components,
        )
        torch.save(checkpoint, path)

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "MixtureModel":
        """Loads a model from a file and restores its full state.

        The checkpoint is expected to be the dict written by ``save``. The
        model is rebuilt from the stored config, the saved weights are loaded
        into freshly constructed rules and expert models, and both are
        returned in eval mode on ``device``.

        Side effect: the global torch and NumPy RNGs are re-seeded with
        ``config.seed``.

        Args:
            path: Checkpoint location handed to ``torch.load``.
            device: Device string used for ``map_location`` and for the
                restored model.

        Returns:
            An instance of the class ``load`` was called on.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects
        # (config, preprocessor, sklearn GMM). Only load trusted checkpoints.
        state = torch.load(path, map_location=device, weights_only=False)
        config = state["config"]

        torch.manual_seed(config.seed)
        np.random.seed(config.seed)

        # Instantiate via ``cls`` so subclasses get their own type back.
        model = cls(config)
        model.preprocessor = state["preprocessor"]
        model.feature_names = state["feature_names"]
        model.device = torch.device(device)
        model.gmm_model = state["gmm_model"]

        # Placeholder tensors with the right feature dimensions, set before
        # the rules model is built; their values are never used.
        dummy_x_shape = len(model.feature_names)
        dummy_y_shape = model.preprocessor.scaler_y.n_features_in_
        model.X_scaled = torch.zeros(
            (2, dummy_x_shape), device=model.device, dtype=torch.float32
        )
        model.Y_scaled = torch.zeros(
            (2, dummy_y_shape), device=model.device, dtype=torch.float32
        )

        model._setup_rules_model()
        model._initialize_experts()

        model.rules_model.load_state_dict(state["rules_model_state_dict"])
        model.expert_model.load_state_dict(state["expert_model_state_dict"])

        model.metrics = state["metrics"]
        model.history = state["history"]

        # Checkpoints without the key fall back to "nothing disabled".
        model.disabled_components = state.get(
            "disabled_components", [False] * config.n_mixture_components
        )
        if model.rules_model:
            # zip stops at the shorter sequence, so surplus flags are ignored.
            for rule, is_disabled in zip(
                model.rules_model.rules, model.disabled_components
            ):
                rule.disabled = is_disabled

        model.rules_model.to(model.device)
        model.expert_model.to(model.device)
        model.rules_model.eval()
        model.expert_model.eval()

        return model
