import torch
import torch.nn as nn
import numpy as np
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
import time

from emm.utils.data_preprocessor import DataPreprocessor
from emm.utils.train_utils import (
    TemperatureAnnealer,
    create_auto_temperature_annealer,
)
from .neural_rules import SimpleMixtureRules
from .cutpoint_initializer import get_initializer
from .training_config import TrainingConfig
from .plotting import ModelPlotter, TrainingPlotter


@dataclass
class TrainingSnapshot:
    """
    Record of the model's state and losses at one training step.

    Snapshots are collected in the model's history and are later used for
    history plotting. On construction every state entry is copied to the CPU
    and every tensor-valued loss is converted to a plain Python number, so a
    snapshot never shares storage with the live model.
    """

    step: int
    mixture_rules_state: Dict[str, Any]
    expert_model_state: Dict[str, Any]
    disabled_components: List[bool]
    current_temp: float
    losses: Dict[str, float]

    def __post_init__(self):
        """
        Replace the state dicts with detached CPU copies and turn tensor
        losses into scalars.
        """
        self.mixture_rules_state = self._detached_copy(self.mixture_rules_state)
        self.expert_model_state = self._detached_copy(self.expert_model_state)

        scalar_losses = {}
        for name, value in self.losses.items():
            if isinstance(value, torch.Tensor):
                scalar_losses[name] = value.item()
            else:
                scalar_losses[name] = value
        self.losses = scalar_losses

    @staticmethod
    def _detached_copy(state):
        """Return a new dict whose values are detached CPU clones of `state`'s values."""
        return {name: value.cpu().detach().clone() for name, value in state.items()}


class MixtureModel(ABC):
    """
    Abstract base class for an interpretable rule-based mixture of experts model.
    """

    def __init__(self, config: TrainingConfig):
        """
        Args:
            config: A TrainingConfig object.
        """
        self.config = config
        self.preprocessor = DataPreprocessor(
            scaler_x_type=config.scaler_x_type, scaler_y_type=config.scaler_y_type
        )
        self.rules_model: Optional[SimpleMixtureRules] = None
        self.expert_model: Optional[nn.Module] = None
        self.history: List[TrainingSnapshot] = []
        self.metrics: Dict[str, Any] = {}
        self.search_history: Optional[List[Dict[str, Any]]] = None
        self.device = torch.device(config.device)
        self.X_original: Optional[np.ndarray] = None
        self.Y_original: Optional[np.ndarray] = None
        self.X_scaled: Optional[torch.Tensor] = None
        self.Y_scaled: Optional[torch.Tensor] = None
        self.feature_names: Optional[List[str]] = None
        self.disabled_components: Optional[List[bool]] = None

    @abstractmethod
    def _fit_single_model(
        self, X: np.ndarray, Y: np.ndarray, feature_names: Optional[List[str]] = None
    ):
        """
        Train one model with the configuration currently stored on the instance.

        Subclasses must implement their specific training loop here. The
        method should be self-contained and perform its own data setup.
        It is called directly by `fit` when the model finder is off, and once
        per candidate by `_find_best_model` when it is on.

        Args:
            X: Input features.
            Y: Target values.
            feature_names: Optional names for the columns of X.
        """
        pass

    def fit(
        self, X: np.ndarray, Y: np.ndarray, feature_names: Optional[List[str]] = None
    ):
        """
        Main entry point for training.

        If `config.use_model_finder` is True, this method orchestrates a search
        over different numbers of components to find the best model based on BIC.
        The state from the best model is then transferred to this instance.

        If `config.use_model_finder` is False, it proceeds with training a single
        model using the provided configuration.
        """
        if not self.config.use_model_finder:
            self._fit_single_model(X, Y, feature_names)
            return

        best_model, search_history = self._find_best_model(X, Y, feature_names)

        if self.config.model_finder_return_history:
            self.search_history = search_history

        if self.config.verbose:
            print("--- Transferring state from best model to current instance ---")

        self.config = best_model.config
        self.preprocessor = best_model.preprocessor
        self.rules_model = best_model.rules_model
        self.expert_model = best_model.expert_model
        self.history = best_model.history
        self.metrics = best_model.metrics
        self.X_original = best_model.X_original
        self.Y_original = best_model.Y_original
        self.X_scaled = best_model.X_scaled
        self.Y_scaled = best_model.Y_scaled
        self.feature_names = best_model.feature_names
        self.disabled_components = best_model.disabled_components
        if hasattr(best_model, "gmm_model"):
            self.gmm_model = best_model.gmm_model

    def _find_best_model(
        self, X: np.ndarray, Y: np.ndarray, feature_names: Optional[List[str]] = None
    ) -> Tuple["MixtureModel", List[Dict[str, Any]]]:
        """
        Trains multiple models and returns the best-performing instance.
        """
        if not self.config.model_finder_component_range:
            raise ValueError(
                "model_finder_component_range must be set in TrainingConfig when use_model_finder is True."
            )

        best_model_instance: Optional[MixtureModel] = None
        best_bic_score = float("inf")
        search_history_list: List[Dict[str, Any]] = []
        verbose = self.config.verbose

        for n_components in self.config.model_finder_component_range:
            if verbose:
                print(
                    f"\n{'='*20} Training {self.__class__.__name__} with {n_components} components {'='*20}"
                )

            current_config = self.config.copy()
            current_config.n_mixture_components = n_components

            model_instance = self.__class__(config=current_config)

            start_time = time.time()
            try:
                model_instance._fit_single_model(X, Y, feature_names=feature_names)
            except Exception as e:
                print(
                    f"ERROR: Training failed for {n_components} components with exception: {e}"
                )
                if verbose:
                    import traceback

                    traceback.print_exc()
                continue

            elapsed_time = time.time() - start_time
            if verbose:
                print(f"Finished training in {elapsed_time:.2f}s.")

            metrics = model_instance.get_metrics()
            current_bic_score = metrics.get("bic")

            if current_bic_score is None:
                print(
                    f"Warning: BIC score not found for n_components = {n_components}. Skipping."
                )
                continue

            if verbose:
                print(
                    f"Model with {n_components} components: BIC = {current_bic_score:.4f}"
                )

            search_history_list.append(
                {
                    "n_components": n_components,
                    "bic_score": current_bic_score,
                    "training_time": elapsed_time,
                    "metrics": metrics,
                }
            )

            if current_bic_score < best_bic_score:
                best_bic_score = current_bic_score
                best_model_instance = model_instance
                if verbose:
                    print(
                        f"*** New best model found with {n_components} components (BIC: {best_bic_score:.4f}) ***"
                    )

        if best_model_instance is None:
            raise ValueError(
                "Model search failed. No model was trained successfully. Check config and data."
            )

        if verbose:
            best_n_comp = best_model_instance.config.n_mixture_components
            print(f"\n{'='*20} Search Complete {'='*20}")
            print(
                f"Best model found has {best_n_comp} components with BIC: {best_bic_score:.4f}"
            )

        return best_model_instance, search_history_list

    @abstractmethod
    def _initialize_experts(self):
        """
        Abstract method for initializing the expert model.

        NOTE(review): presumably implementations populate `self.expert_model`;
        confirm against the concrete subclasses.
        """
        pass

    @abstractmethod
    def _initialize_optimizers(self):
        """
        Abstract method for initializing the optimizers.

        The base class does not call this itself; each subclass decides when
        to invoke it inside its own training loop.
        """
        pass

    @abstractmethod
    def _get_log_likelihood(
        self, X_tensor: torch.Tensor, Y_tensor: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the per-sample log-likelihood in the SCALED data space.

        Subclasses must implement this logic. The public `log_likelihood`
        method calls it under `torch.no_grad()` and then adds the correction
        that converts the result to the original data space.

        Args:
            X_tensor: A torch.Tensor of shape (n_samples, n_features) with scaled data.
            Y_tensor: A torch.Tensor of shape (n_samples, y_dim) with scaled data.

        Returns:
            A torch.Tensor of shape (n_samples,) containing the log-likelihoods.
        """
        pass

    @abstractmethod
    def _reorder_experts(self, sort_idx: torch.Tensor):
        """
        Reorder the expert models to match a new rule ordering.

        Called by `sort_components_by_mean` with the index sequence returned
        by the rules model's `sort_rules`. Subclasses must implement the
        specific logic for their expert model.

        Args:
            sort_idx: The new component order produced by rule sorting.
        """
        pass

    @abstractmethod
    def get_expert_densities(self, y_range: np.ndarray) -> np.ndarray:
        """
        Compute the probability density p_j(y) of every expert j over a given
        range of y values in the original data space.

        Subclasses must implement this for their expert model.

        Args:
            y_range: A numpy array of shape (n_points, y_dim) at which to
                     evaluate the densities.

        Returns:
            A numpy array of shape (n_points, n_components) containing the
            densities for each component.
        """
        pass

    def sort_components_by_mean(self):
        """
        Sorts the mixture components (rules and corresponding experts) in-place
        based on the mean of the target variable Y for each rule.
        """
        if not self.rules_model or self.X_scaled is None or self.Y_scaled is None:
            raise RuntimeError("Model must be trained before components can be sorted.")

        if self.config.verbose:
            print("\n--- Sorting components by mean of target variable ---")

        sort_idx = self.rules_model.sort_rules(
            self.X_scaled,
            torch.tensor(
                self.Y_original, dtype=self.Y_scaled.dtype, device=self.device
            ),
        )

        if self.disabled_components:
            self.disabled_components = [self.disabled_components[i] for i in sort_idx]
            for i, rule in enumerate(self.rules_model.rules):
                rule.disabled = self.disabled_components[i]

        self._reorder_experts(sort_idx)

        if self.config.verbose:
            print(f"New component order: {[i.item() + 1 for i in sort_idx]}")
            print("--- Sorting complete ---")

    def log_likelihood(self, X: np.ndarray, Y: np.ndarray) -> np.ndarray:
        """
        Computes the log-likelihood log p(y|x) for new data, correctly scaled
        to the original data space.

        Args:
            X: A numpy array of shape (n_samples, n_features) with original data.
            Y: A numpy array of shape (n_samples, y_dim) with original data.

        Returns:
            A numpy array of shape (n_samples,) containing the log-likelihoods.
        """
        if not self.rules_model or not self.expert_model:
            raise RuntimeError("Model has not been trained yet. Call fit() first.")

        self.rules_model.eval()
        self.expert_model.eval()
        self.rules_model.to(self.device)
        self.expert_model.to(self.device)

        X_tensor, Y_tensor, _, _ = self.preprocessor.transform(X, Y)
        X_tensor, Y_tensor = X_tensor.to(self.device), Y_tensor.to(self.device)

        with torch.no_grad():
            # Get the log-likelihood on the scaled data space
            log_likelihood_scaled = self._get_log_likelihood(X_tensor, Y_tensor)

            # Apply the Jacobian correction to convert to original data space.
            # log p(y) = log p_scaled(y_s) + log |det(J)| where J is the Jacobian
            # of the transformation y_s = f(y). For a standard scaler y_s = (y-mu)/sigma,
            # the determinant of the Jacobian is 1/sigma for each dimension.
            # So, log p(y) = log p_scaled(y_s) - sum(log(sigma_i)).
            log_jacobian_determinant = 0.0
            scaler_y = self.preprocessor.scaler_y
            if hasattr(scaler_y, "scale_") and scaler_y.scale_ is not None:
                # The log determinant of the Jacobian for y_s = (y-mu)/sigma is -sum(log(sigma_i))
                log_jacobian_determinant = -np.sum(np.log(scaler_y.scale_ + 1e-9))

            # log p(y) = log p_s(y_s) + log_jacobian_determinant
            log_likelihood_orig = log_likelihood_scaled + log_jacobian_determinant

        return log_likelihood_orig.cpu().numpy()

    def get_nll(self, X: np.ndarray, Y: np.ndarray) -> float:
        """
        Computes the Negative Log-Likelihood (NLL) for new data,
        correctly scaled to the original data space. This is a convenience
        method that returns a single scalar value.

        Args:
            X: A numpy array of shape (n_samples, n_features) with original data.
            Y: A numpy array of shape (n_samples, y_dim) with original data.

        Returns:
            A float representing the mean NLL.
        """
        log_probs = self.log_likelihood(X, Y)
        return -np.mean(log_probs)

    def get_responsibilities(self, X: np.ndarray) -> np.ndarray:
        """Returns the gating probabilities s_j(x) for each component."""
        if not self.rules_model:
            raise RuntimeError("Model has not been trained yet. Call fit() first.")

        self.rules_model.eval()
        self.rules_model.to(self.device)

        # Create a dummy Y for transformation
        dummy_y_shape = (
            X.shape[0],
            self.Y_original.shape[1] if self.Y_original.ndim > 1 else 1,
        )
        X_tensor, _, _, _ = self.preprocessor.transform(X, np.zeros(dummy_y_shape))
        X_tensor = X_tensor.to(self.device)

        with torch.no_grad():
            responsibilities, _ = self.rules_model(X_tensor)

        return responsibilities.cpu().numpy()

    def get_labels(self, X):
        """
        Returns the component labels for each data point in X.
        """
        responsibilities = self.get_responsibilities(X)
        return np.argmax(responsibilities, axis=1)

    def get_activations(self, X: np.ndarray) -> np.ndarray:
        """Returns the activations a_j(x) for each component."""
        if not self.rules_model:
            raise RuntimeError("Model has not been trained yet. Call fit() first.")

        self.rules_model.eval()
        self.rules_model.to(self.device)

        dummy_y_shape = (
            X.shape[0],
            self.Y_original.shape[1] if self.Y_original.ndim > 1 else 1,
        )
        X_tensor, _, _, _ = self.preprocessor.transform(X, np.zeros(dummy_y_shape))
        X_tensor = X_tensor.to(self.device)

        with torch.no_grad():
            activations = self.rules_model.forward_raw(X_tensor)

        return activations.cpu().numpy()

    def get_rules(
        self,
        output_format: str = "text",
        output_dir="plots",
        responsibility_threshold=0.1,
        activation_threshold=0.001,
        weight_threshold=0.1,
        assign_max_resp=False,
        rules_to_plot: Optional[List[int]] = None,
        show_proportional_dist: bool = True,
        show_density_dist: bool = False,
        show_population_histogram: bool = False,
        scale_densities_by_weight: bool = False,
        y_name: str = "Y",
        **kwargs,
    ) -> str:
        """
        Extracts and displays rules in a structured, comparable format.

        This is the new, recommended method for inspecting rules. It uses the
        centralized MixturePlotter to generate a text table, an HTML file, or a plot,
        ensuring that all data transformations are handled correctly for visualization
        in the original data space.

        Args:
            output_format (str, optional): 'text', 'html', 'plot', or 'plot_condensed'. Defaults to 'text'.
            rules_to_plot (List[int], optional): For combined/condensed plots, a list of rule indices to plot.
            show_proportional_dist (bool): For 'plot_combined', show proportional distribution (hist/scatter).
                                           This is ignored if `show_population_histogram` is True.
            show_density_dist (bool): For 'plot_combined', show expert density distribution.
            show_population_histogram (bool): For 1D 'plot_combined', shows a gray histogram of the
                                              entire population in the background, like a standard GMM plot.
            scale_densities_by_weight (bool): For 1D 'plot_combined' with densities, scales each component's
                                              density by its average responsibility over the dataset.
            **kwargs: Additional arguments passed to the plotter, such as `filepath`,
                      `weight_threshold`, `show_densities`, etc.

        Returns:
            str: The formatted string (for 'text'/'html') or a confirmation message.
        """
        if not self.rules_model:
            raise RuntimeError("Model has not been trained or set up properly.")

        plotter = ModelPlotter(output_dir=output_dir, y_name=y_name)
        return plotter.plot_rules_summary(
            model=self,
            output_format=output_format,
            activation_threshold=activation_threshold,
            weight_threshold=weight_threshold,
            responsibility_threshold=responsibility_threshold,
            assign_max_resp=assign_max_resp,
            rules_to_plot=rules_to_plot,
            show_proportional_dist=show_proportional_dist,
            show_density_dist=show_density_dist,
            show_population_histogram=show_population_histogram,
            scale_densities_by_weight=scale_densities_by_weight,
            **kwargs,
        )

    def plot_expert_densities(self, output_dir="plots", **kwargs):
        """
        Plots the density of each expert component (e.g., GMM or Flow).

        This method provides a high-level interface to visualize the learned
        distributions of the expert models on the original data scale.

        Args:
            **kwargs: Additional arguments passed to the plotter, such as `save_name`,
                      `n_points`, etc.
        """
        if not self.expert_model:
            raise RuntimeError("Model has not been trained yet. Call fit() first.")

        plotter = ModelPlotter(output_dir=output_dir)
        plotter.plot_expert_densities(model=self, **kwargs)

    def plot_training_animation(self, output_dir="plots", **kwargs):
        """
        Generates an animation of the training process.
        """
        if not self.history:
            print("Warning: No training history found. Cannot create animation.")
            return
        plotter = TrainingPlotter(output_dir=output_dir)
        plotter.plot_training_animation_2d(model=self, **kwargs)

    def plot_training_snapshots(
        self, snapshot_steps: List[int], output_dir="plots", **kwargs
    ):
        """
        Generates a static plot of training snapshots.
        """
        if not self.history:
            print("Warning: No training history found. Cannot create plots.")
            return
        plotter = TrainingPlotter(output_dir=output_dir)
        plotter.plot_training_snapshots_2d(
            model=self, snapshot_steps=snapshot_steps, **kwargs
        )

    def plot_gating_heatmap(self, output_dir="plots", **kwargs):
        """
        Generates a heatmap visualization of the gating activations or responsibilities.
        """
        if not self.history:
            print("Warning: No training history found. Cannot create plots.")
            return
        plotter = ModelPlotter(output_dir=output_dir)
        plotter.plot_gating_heatmap(model=self, **kwargs)

    def get_metrics(self) -> dict:
        """
        Returns a dictionary of key metrics (NLL, BIC, AIC, etc.).

        The returned object is the instance's own `metrics` dict, not a copy,
        so changes made by the caller are visible on the model. It is empty
        until metrics have been calculated.
        """
        return self.metrics

    def get_rule_complexity_metrics(self) -> Dict[str, float]:
        """
        Calculates and returns metrics related to rule complexity.

        Returns:
            A dictionary with 'avg_rule_complexity' and 'total_conditions'.
        """
        if not self.rules_model:
            return {"avg_rule_complexity": 0.0, "total_conditions": 0.0}

        total_conditions = 0
        active_rules_count = 0
        for rule in self.rules_model.rules:
            if not rule.disabled:
                utilized_features = rule.get_utilized_features()
                total_conditions += len(utilized_features)
                active_rules_count += 1

        avg_rule_complexity = (
            total_conditions / active_rules_count if active_rules_count > 0 else 0.0
        )

        return {
            "rule_complexity": avg_rule_complexity,
            "total_conditions": float(total_conditions),
        }

    @abstractmethod
    def save(self, path: str):
        """
        Saves the trained model state and configuration.

        Subclasses must implement this; the on-disk format is up to them.

        Args:
            path: Destination to write to.
        """
        pass

    @classmethod
    @abstractmethod
    def load(cls, path: str, device: str = "cpu") -> "MixtureModel":
        """
        Loads a model from a file.

        Subclasses must implement this as the counterpart of `save`.

        Args:
            path: Location of the saved model.
            device: Device name for the loaded model; defaults to "cpu".

        Returns:
            A model instance restored from `path`.
        """
        pass

    def _setup_rules_model(self):
        """Initializes just the SimpleMixtureRules model."""
        initializer = get_initializer(
            X=self.X_scaled,
            n_components=self.config.n_mixture_components,
            **(self.config.initializer_config or {}),
        )
        self.rules_model = SimpleMixtureRules(
            X=self.X_scaled,
            n_components=self.config.n_mixture_components,
            temperature=self.config.temperature,
            initializer=initializer,
            use_background_component=self.config.use_background_component,
            background_epsilon=self.config.background_epsilon,
        ).to(self.device)

    def _setup_training(self, X, Y, feature_names):
        """Prepares data, device, seeds, and initializes models."""
        torch.manual_seed(self.config.seed)
        np.random.seed(self.config.seed)

        self.X_original, self.Y_original = X, Y
        (
            self.X_scaled,
            self.Y_scaled,
            _,
            _,
        ) = self.preprocessor.fit_transform(X, Y)
        self.X_scaled = self.X_scaled.to(self.device)
        self.Y_scaled = self.Y_scaled.to(self.device)

        self.feature_names = (
            feature_names
            if feature_names is not None
            else [f"X{i}" for i in range(X.shape[1])]
        )

        self._setup_rules_model()

        self.disabled_components = [False] * self.config.n_mixture_components

        if self.config.temp_anneal == "auto":
            self.temp_annealer = create_auto_temperature_annealer(
                self.config.temperature
            )
        elif isinstance(self.config.temp_anneal, dict):
            self.temp_annealer = TemperatureAnnealer.from_dict(self.config.temp_anneal)
        elif isinstance(self.config.temp_anneal, TemperatureAnnealer):
            self.temp_annealer = self.config.temp_anneal
        else:
            self.temp_annealer = None

    def _anneal_parameters(self, step: int):
        """Updates temperature based on the annealing schedule."""
        if self.temp_annealer is not None:
            new_temp = self.temp_annealer.get_temperature(
                step, self.config.component_train_epochs
            )
            self.rules_model.set_temperature(new_temp)

    def _check_and_update_component_status(self, step: int):
        """Checks component responsibilities and updates the disabled status."""
        if step > 0 and step % self.config.check_responsibility_every == 0:
            with torch.no_grad():
                full_rule_probs, _ = self.rules_model(self.X_scaled)
                interpretable_probs = full_rule_probs[
                    :, : self.config.n_mixture_components
                ]
                mean_responsibilities = torch.mean(interpretable_probs, dim=0)

                for i, resp in enumerate(mean_responsibilities):
                    if resp < self.config.min_responsibility_threshold:
                        if not self.disabled_components[i] and self.config.verbose:
                            print(
                                f"Step {step}: Disabling component {i+1} with responsibility {resp:.6f}"
                            )
                        self.disabled_components[i] = True
                        self.rules_model.rules[i].disabled = True

    def _prune_components(self) -> bool:
        """
        Post-hoc pruning of components that claim very few data points.
        Returns True if any component was pruned.
        """
        if self.config.pruning_threshold <= 0.0 or not self.rules_model:
            return False

        pruned = False
        with torch.no_grad():
            full_rule_probs, _ = self.rules_model(self.X_scaled)
            interpretable_probs = full_rule_probs[:, : self.config.n_mixture_components]
            hard_assignments = torch.argmax(interpretable_probs, dim=1)

            for j in range(self.config.n_mixture_components):
                n_claimed = torch.sum(hard_assignments == j).item()
                if (
                    n_claimed < self.config.pruning_threshold
                    and not self.disabled_components[j]
                ):
                    if self.config.verbose:
                        print(
                            f"Post-hoc pruning component {j+1} with {n_claimed} claimed points."
                        )
                    self.disabled_components[j] = True
                    self.rules_model.rules[j].disabled = True
                    pruned = True
        return pruned

    def _merge_components(self) -> bool:
        """
        Placeholder for merging logic. Subclasses should implement this.
        """
        if self.config.merge_components and self.config.verbose:
            print("Note: Merging not implemented for this model class.")
        return False

    def _settle_components(self):
        """
        Placeholder for settling logic. Subclasses should implement this.
        """
        if self.config.merge_settle_epochs > 0 and self.config.verbose:
            print("Note: Settling not implemented for this model class.")
        pass

    def _calculate_final_metrics(self, final_nll: float):
        """
        Calculates final metrics like AIC and BIC.
        """
        if not self.rules_model or not self.expert_model:
            print("Warning: Cannot calculate metrics. Models not initialized.")
            return

        n_rule_params = self.rules_model.count_active_parameters()
        total_parameters = n_rule_params
        n_samples = self.X_scaled.shape[0]

        if not np.isnan(final_nll):
            aic = 2 * total_parameters + 2 * n_samples * final_nll
            bic = total_parameters * np.log(n_samples) + 2 * n_samples * final_nll
        else:
            aic, bic = np.nan, np.nan

        self.metrics["final_nll"] = final_nll
        self.metrics["aic"] = aic
        self.metrics["bic"] = bic
        self.metrics["total_parameters"] = total_parameters
        if self.config.verbose:
            print("\n--- Final Metrics ---")
            print(f"Final NLL: {final_nll:.4f}")
            print(f"Total Active Parameters: {total_parameters}")
            print(f"AIC: {aic:.2f}")
            print(f"BIC: {bic:.2f}")
            print("---------------------\n")

    def _log_progress(self, step: int, losses: dict):
        """
        Standardized logging of training progress and records a history snapshot.
        """
        if (
            self.config.verbose
            and step > 0
            and step % (max(1, self.config.component_train_epochs // 10)) == 0
        ):
            print(f"\nStep {step}/{self.config.component_train_epochs}")
            loss_str = ", ".join([f"{name}: {val:.4f}" for name, val in losses.items()])
            print(f"Losses: {loss_str}")

        # Record history snapshot
        if (
            self.config.record_history_every > 0
            and step % self.config.record_history_every == 0
        ):
            self._record_history_snapshot(step, losses)

    def _record_history_snapshot(self, step: int, losses: Dict[str, float]):
        """
        Creates a TrainingSnapshot and appends it to the model's history.
        """
        if not self.rules_model or not self.expert_model:
            return

        snapshot = TrainingSnapshot(
            step=step,
            mixture_rules_state=self.rules_model.state_dict(),
            expert_model_state=self.expert_model.state_dict(),
            disabled_components=list(self.disabled_components),
            current_temp=self.rules_model.rules[0].discretizer.temperature,
            losses=losses,
        )
        self.history.append(snapshot)
