from typing import Optional

import numpy as np

from mpf_py._mpf_py import MPF as _MPF
from mpf_py._mpf_py import FitResult
from mpf_py._mpf_py import TreeGrid as _TreeGrid
from mpf_py.wrapper import PythonWrapperClassBase

# Integer codes passed to the Rust backend for each value accepted by
# ``MPF.fit(refinement_strategy=...)``.
REFINEMENT_STRATEGY_MAP = dict(
    l2=1,
    huber=2,
)

# Integer codes passed to the Rust backend for each value accepted by
# ``MPF.fit(split_strategy=...)``.
SPLIT_STRATEGY_MAP = dict(
    random=1,
    best_split=2,
    top_k=3,
)


class TreeGrid(PythonWrapperClassBase, RustClass=_TreeGrid):
    """Python wrapper around a Rust ``TreeGrid``.

    Attributes such as ``intervals``, ``mean_factor`` and ``scaling`` are read
    from ``self`` but not defined in this class; presumably
    ``PythonWrapperClassBase`` forwards them to the wrapped Rust instance —
    TODO confirm against ``mpf_py.wrapper``.
    """

    def __init__(self, tg):
        """Wrap an already-constructed Rust ``TreeGrid`` instance ``tg``."""
        super().__init__()
        self._rust_instance = tg

    @classmethod
    def fit(
        cls,
        x: np.typing.NDArray[np.float64],
        y: np.typing.NDArray[np.float64],
        n_iter: int,
        split_try: int,
        colsample_bytree: float,
        complexity_penalty: float = 0.0,
        seed: int = 42,
    ) -> tuple["TreeGrid", "FitResult"]:
        """Fit a single TreeGrid to ``(x, y)`` with the Rust backend.

        Parameters
        ----------
        x : np.ndarray
            Training input samples; converted to a C-contiguous float64 array.
        y : np.ndarray
            Target values; converted to a C-contiguous float64 array.
        n_iter : int
            Number of fitting iterations.
        split_try : int
            Number of split points to try for each feature.
        colsample_bytree : float
            Subsample ratio of columns.
        complexity_penalty : float, default=0.0
            Complexity penalty forwarded to the backend.
        seed : int, default=42
            Random seed for reproducibility.

        Returns
        -------
        tuple[TreeGrid, FitResult]
            The wrapped fitted grid and the backend's fit result.
        """
        # The Rust side expects contiguous float64 buffers.
        x = np.ascontiguousarray(x, dtype=np.float64)
        y = np.ascontiguousarray(y, dtype=np.float64)
        tg, fr = _TreeGrid.fit(
            x,
            y,
            n_iter,
            split_try,
            colsample_bytree,
            complexity_penalty,
            seed,
        )
        return cls(tg), fr

    def predict(
        self, x: np.typing.NDArray[np.float64]
    ) -> np.typing.NDArray[np.float64]:
        """Predict with the wrapped grid; ``x`` is made C-contiguous float64 first."""
        x = np.ascontiguousarray(x, dtype=np.float64)
        return self._rust_instance.predict(x)

    def get_component(self, axis: int):
        """Return the component for ``axis`` as a list of ``(interval, value)`` pairs.

        Pairs ``self.intervals[axis]`` element-wise with
        ``self.mean_factor[axis]``; ``zip`` stops at the shorter of the two.
        """
        return list(zip(self.intervals[axis], self.mean_factor[axis]))

    @staticmethod
    def _step_points(intervals, values):
        """Flatten finite ``(start, end)`` intervals into step-plot coordinates.

        Each drawable interval contributes ``[start, end]`` to the x list and
        ``[value, value]`` to the y list. Intervals with ``start == -inf`` or
        ``end == inf`` are skipped because they cannot be drawn on a finite axis.
        Intervals are used in the order given (no sorting).
        """
        x_points = []
        y_points = []
        for (x_start, x_end), y in zip(intervals, values):
            if x_start == float("-inf") or x_end == float("inf"):
                continue
            x_points.extend([x_start, x_end])
            y_points.extend([y, y])
        return x_points, y_points

    @staticmethod
    def _finish_figure(plt, title, show_legend):
        """Apply the shared labels, grid and (optional) legend, then show the figure."""
        plt.xlabel("X-axis")
        plt.ylabel("Value")
        plt.title(title)
        plt.grid(True)
        if show_legend:
            # Calling plt.legend() with no labelled artists only emits a warning.
            plt.legend()
        plt.show()

    def plot_components(
        self, individual_plots: bool = False, axis: Optional[int] = None
    ):
        """Plot the one-dimensional components of this grid as step functions.

        Parameters
        ----------
        individual_plots : bool, default=False
            If True, draw and show one figure per axis. Otherwise overlay all
            selected axes on a single figure.
        axis : int, optional
            Plot only this axis. It must satisfy
            ``0 <= axis < len(self.intervals)``. By default every axis is plotted.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        ValueError
            If ``axis`` is out of range.

        Notes
        -----
        Intervals with an infinite start or end are not drawn.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError(
                "Matplotlib is required to use the 'plot' function. "
                "Please install it using: pip install matplotlib"
            ) from err

        n_components = len(self.intervals)
        # One distinct viridis colour per component, indexed by axis number.
        colors = plt.cm.viridis(np.linspace(0, 1, n_components))

        if axis is not None:
            if not 0 <= axis < n_components:
                raise ValueError(f"axis must be between 0 and {n_components - 1}")
            axes_to_plot = [(axis, (self.intervals[axis], self.mean_factor[axis]))]
        else:
            axes_to_plot = enumerate(zip(self.intervals, self.mean_factor))

        # A combined plot shares one figure created upfront.
        if not individual_plots:
            plt.figure(figsize=(10, 6))

        any_plotted = False
        for axis_idx, (intervals, values) in axes_to_plot:
            if individual_plots:
                plt.figure(figsize=(10, 6))

            x_points, y_points = self._step_points(intervals, values)
            plotted = bool(x_points)
            if plotted:
                plt.step(
                    x_points,
                    y_points,
                    where="post",
                    lw=1,
                    color=colors[axis_idx],
                    label=f"Axis {axis_idx}",
                )
                any_plotted = True

            if individual_plots:
                self._finish_figure(
                    plt,
                    f"TreeGrid Component for Axis {axis_idx}, Scaling: {self.scaling}",
                    show_legend=plotted,
                )

        if not individual_plots:
            self._finish_figure(
                plt,
                f"TreeGrid One-Dimensional Components, Scaling: {self.scaling}",
                show_legend=any_plotted,
            )


class MPF(PythonWrapperClassBase, RustClass=_MPF):
    """Python wrapper around the Rust boosted ``MPF`` model.

    Instances are normally created with :meth:`fit` or :meth:`load`. Besides
    the wrapped Rust object, those constructors cache the per-epoch tree grids
    as :class:`TreeGrid` wrappers in ``_tree_grid_families_lst``, which
    :meth:`plot_epoch_components` relies on.
    """

    class _InvalidOptionError(ValueError, KeyError):
        """Raised when a string option (e.g. ``split_strategy``) is unknown.

        Inherits from ``KeyError`` as well as ``ValueError`` so callers that
        caught the bare ``KeyError`` raised by earlier versions keep working.
        """

        # Show the plain message instead of KeyError's repr-quoted form.
        __str__ = ValueError.__str__

    def __init__(self, mpf_boosted):
        """Wrap an already-constructed Rust ``MPF`` instance ``mpf_boosted``."""
        super().__init__()
        self._rust_instance = mpf_boosted

    @staticmethod
    def _resolve_option(name, value, mapping):
        """Return ``mapping[value]``, raising a descriptive error for unknown values."""
        try:
            return mapping[value]
        except KeyError:
            raise MPF._InvalidOptionError(
                f"Unknown {name} {value!r}; expected one of {list(mapping)}"
            ) from None

    @staticmethod
    def _wrap_tree_grid_families(rust_instance):
        """Wrap each Rust tree grid as a TreeGrid, grouped per tree grid family (epoch)."""
        return [
            [TreeGrid(tg) for tg in tf.tree_grids]
            for tf in rust_instance.tree_grid_families
        ]

    @classmethod
    def fit(
        cls,
        x: np.typing.NDArray[np.float64],
        y: np.typing.NDArray[np.float64],
        epochs: int,
        decay: float = 1.0,
        n_trees: int = 10,
        n_iter: int = 10,
        split_try: int = 10,
        colsample_bytree: float = 0.8,
        alpha: float = 0.0,
        complexity_penalty: float = 0.0,
        min_split_loss: float = 0.0,
        min_interval_samples: int = 1,
        refinement_strategy: str = "l2",
        prior_sample_size: float = 0.0,
        update_clamp: float = float("inf"),
        tilt_tau: float = 0.01,
        tilt_rho: float = 0.0,
        split_strategy: str = "random",
        top_k: int = 10,
        must_fill_all_k: bool = True,
        similarity_threshold: float = 0.0,
        bagged: bool = False,
        seed: int = 42,
        verbosity: int = 1,
        visualdb: Optional[str] = None,
    ) -> tuple["MPF", "FitResult"]:
        """Fit an MPF boosted model to the training data.

        Parameters
        ----------
        x : np.ndarray
            Training input samples.
        y : np.ndarray
            Target values.
        epochs : int
            Number of boosting epochs.
        decay : float, default=1.0
            The decay rate for the number of splits in each epoch.
        n_trees : int, default=10
            Number of trees per epoch.
        n_iter : int, default=10
            Number of iterations for fitting each TreeGrid.
        split_try : int, default=10
            Number of split points to try for each feature.
        colsample_bytree : float, default=0.8
            Subsample ratio of columns when constructing each tree.
        alpha : float, default=0.0
            Regularization parameter.
        complexity_penalty : float, default=0.0
            Complexity penalty (lambda) for adaptive merge bonus. This is BIC-inspired and scale-invariant.
            Larger values encourage simpler models. Typical values: 0.5-2.0. Default: 0.0 (no complexity penalty).
        min_split_loss : float, default=0.0
            Minimum loss reduction required to split a node.
        min_interval_samples : int, default=1
            Minimum number of samples in an interval.
        refinement_strategy : str, default="l2"
            Strategy for refining grid values ("l2", "huber").
        prior_sample_size : float, default=0.0
            Prior sample size (tau_0) for parent anchoring. Interpreted as "how many samples worth of confidence
            we have that children should equal their parent". With tau_0=30, a child with 10 samples will be
            heavily shrunk toward the parent, while a child with 100 samples will mostly trust its own data.
            Default: 0.0 (no anchoring). Typical values: 10-50.
        update_clamp : float, default=float('inf')
            Maximum allowed update magnitude for refinement (infinity = no clamping).
        tilt_tau : float, default=0.01
            Two-tensor L2 coupling between u_+ and u_- (objective τ). Controls the strength
            of the quadratic penalty on the difference between positive and negative components.
        tilt_rho : float, default=0.0
            Two-tensor L1 coupling on (u_+ - u_-) (objective ρ). When > 0, can drive the tilt
            exactly to zero on many sides, yielding pure "backbone-only" updates.
        split_strategy : str, default="random"
            Strategy for selecting split points ("random", "best_split", "top_k").
        top_k : int, default=10
            When split_strategy is "top_k", number of top candidate splits to consider.
        must_fill_all_k : bool, default=True
            When split_strategy is "top_k", require all K splits to be filled.
        similarity_threshold : float, default=0.0
            Threshold for merging similar components.
        bagged : bool, default=False
            Whether to use bagging.
        seed : int, default=42
            Random seed for reproducibility.
        verbosity : int, default=1
            Verbosity level for the rust backend.
        visualdb : str, optional, default=None
            Path to SQLite database for saving split events and run information.
            If specified, split events will be recorded during fitting and saved to the database.
            Requires the evo-logging feature to be enabled. If the feature is disabled,
            a warning will be logged.

        Returns
        -------
        tuple[MPF, FitResult]
            The fitted MPF model and fit result.

        Raises
        ------
        ValueError
            If ``refinement_strategy`` or ``split_strategy`` is not a known
            option. The error also subclasses ``KeyError`` for backward
            compatibility.
        """
        # The Rust side expects contiguous float64 buffers.
        x = np.ascontiguousarray(x, dtype=np.float64)
        y = np.ascontiguousarray(y, dtype=np.float64)
        # The backend takes integer codes rather than strategy names.
        refinement_strategy = cls._resolve_option(
            "refinement_strategy", refinement_strategy, REFINEMENT_STRATEGY_MAP
        )
        split_strategy = cls._resolve_option(
            "split_strategy", split_strategy, SPLIT_STRATEGY_MAP
        )
        # Positional order must match the Rust `_MPF.fit` signature exactly.
        mpf_boosted, fr = _MPF.fit(
            x,
            y,
            epochs,
            decay,
            n_trees,
            n_iter,
            split_try,
            colsample_bytree,
            alpha,
            complexity_penalty,
            min_split_loss,
            min_interval_samples,
            refinement_strategy,
            prior_sample_size,
            update_clamp,
            tilt_tau,
            tilt_rho,
            split_strategy,
            top_k,
            must_fill_all_k,
            similarity_threshold,
            bagged,
            seed,
            verbosity,
            visualdb,
        )
        instance = cls(mpf_boosted)
        instance._tree_grid_families_lst = cls._wrap_tree_grid_families(mpf_boosted)
        return instance, fr

    def predict(
        self, x: np.typing.NDArray[np.float64]
    ) -> np.typing.NDArray[np.float64]:
        """Predict with the wrapped model; ``x`` is made C-contiguous float64 first."""
        x = np.ascontiguousarray(x, dtype=np.float64)
        return self._rust_instance.predict(x)

    def compute_partial_dependence_function(
        self,
        fixed_indices: list[int],
        fixed_values: np.typing.NDArray[np.float64],
        data_x: np.typing.NDArray[np.float64],
    ) -> tuple[list[tuple[float, float]], np.typing.NDArray[np.float64]]:
        """Compute partial dependence function for each epoch.

        For each tree grid family (epoch), computes:
            E_{X_S^c}[f(X_S, X_S^c)] = (1/n) ∑_{i=1}^n f(x_S, x_{i,S^c})

        where S is the set of fixed features, S^c is the complement (marginalized features),
        and the expectation is taken over the empirical joint distribution of X_S^c in data_x.

        This correctly preserves feature correlations in the marginalized features, unlike
        methods that assume independence.

        This is useful for:
        - Partial dependence plots: fix some features, marginalize over others
        - Feature effect analysis: see how model responds to changes in specific features
        - Conditional predictions: predict while averaging over unobserved features

        Parameters
        ----------
        fixed_indices : list[int]
            List of feature indices that are fixed (in order).
            The order determines which column in fixed_values corresponds to which feature.
        fixed_values : np.ndarray of shape (n_observations, len(fixed_indices))
            Array of fixed feature values. Each row is one observation.
            Columns correspond to the order of fixed_indices.
            For example, if fixed_indices=[0, 1], then:
            - fixed_values[:, 0] contains values for feature 0
            - fixed_values[:, 1] contains values for feature 1
        data_x : np.ndarray of shape (n_samples, n_features)
            Training data used to estimate the empirical distribution of marginalized features.
            Should be the same data the model was trained on, or a representative sample.

        Returns
        -------
        tuple of (constants_per_epoch, pd_values)
            - constants_per_epoch: list of (C_plus, C_minus) per epoch
              Constants are E[∏_{j ∉ S} f_j(X_j)] (expectation over marginalized features)
              Constants include OLS scaling (computed with effective_lambda = scaling * lambda)
            - pd_values: np.ndarray of shape (n_observations, 2 * n_epochs)
              Array of marginal expectation values for f+ and f- separately.
              - Rows correspond to observations (same order as fixed_values)
              - Columns alternate: [f+_epoch0, f-_epoch0, f+_epoch1, f-_epoch1, ...]
              - Values have scaling absorbed into lambda, so final prediction per epoch = f+ + f- (add columns)
              - To get total across epochs: sum (f+ + f-) across all epochs

        Raises
        ------
        ValueError
            If ``fixed_values`` or ``data_x`` is not 2D, or if ``fixed_values``
            does not have one column per entry of ``fixed_indices``.

        Examples
        --------
        >>> # Single observation: fix latitude and longitude
        >>> fixed_indices = [0, 1]  # Latitude, Longitude
        >>> fixed_values = np.array([[37.5, -122.2]])  # Single observation
        >>> constants, marginal_exp = model.compute_partial_dependence_function(fixed_indices, fixed_values, X_train)
        >>> # constants: list of (C_plus, C_minus) per epoch
        >>> # marginal_exp shape: (1, 2 * n_epochs)
        >>> # Access f+ and f- for epoch 0:
        >>> f_plus_epoch0 = marginal_exp[0, 0]
        >>> f_minus_epoch0 = marginal_exp[0, 1]
        >>> # Access constants for epoch 0:
        >>> c_plus_epoch0, c_minus_epoch0 = constants[0]
        >>>
        >>> # Multiple observations: partial dependence plot
        >>> latitudes = np.linspace(32, 42, 50)
        >>> fixed_indices = [0]  # Only fix latitude
        >>> fixed_values = latitudes.reshape(-1, 1)  # Shape: (50, 1)
        >>> constants, marginal_exp = model.compute_partial_dependence_function(fixed_indices, fixed_values, X_train)
        >>> # marginal_exp shape: (50, 2 * n_epochs)
        >>> # Extract f+ for all epochs: columns 0, 2, 4, ...
        >>> f_plus_all = marginal_exp[:, ::2]  # Shape: (50, n_epochs)
        >>> # Extract f- for all epochs: columns 1, 3, 5, ...
        >>> f_minus_all = marginal_exp[:, 1::2]  # Shape: (50, n_epochs)

        Notes
        -----
        - Uses the two-tensor representation: f = λ₊ * ∏_j a₊,j - λ₋ * ∏_j a₋,j
        - OLS scaling is absorbed into lambda values (effective_lambda = scaling * lambda)
        - Returns f+ and f- separately with scaling already absorbed
        - Final prediction per epoch = f+ + f- (add the two columns)
        - Preserves joint distribution of marginalized features (not assuming independence)
        - Computed efficiently in Rust for performance
        """
        # The Rust side expects contiguous float64 buffers.
        fixed_values = np.ascontiguousarray(fixed_values, dtype=np.float64)
        data_x = np.ascontiguousarray(data_x, dtype=np.float64)

        if fixed_values.ndim != 2:
            raise ValueError(
                f"fixed_values must be 2D array, got shape {fixed_values.shape}"
            )
        if data_x.ndim != 2:
            raise ValueError(f"data_x must be 2D array, got shape {data_x.shape}")
        if fixed_values.shape[1] != len(fixed_indices):
            raise ValueError(
                f"fixed_values must have {len(fixed_indices)} columns "
                f"(one per fixed_indices), got {fixed_values.shape[1]}"
            )

        return self._rust_instance.compute_partial_dependence_function(
            fixed_indices, fixed_values, data_x
        )

    def compute_first_order_partial_dependence_functions(
        self,
        values_x: np.typing.NDArray[np.float64],
        data_x: np.typing.NDArray[np.float64],
    ) -> list[tuple[list[tuple[float, float]], np.typing.NDArray[np.float64]]]:
        """Compute first-order partial dependence functions for each feature.

        For each epoch and feature j, computes constants:
            C_{+,j} = E[lambda_+ * ∏_{k != j} a_{+,k}(X_k)]
            C_{-,j} = E[lambda_- * ∏_{k != j} a_{-,k}(X_k)]

        Then:
            PD_{+,j}(x_j) = C_{+,j} * a_{+,j}(x_j)
            PD_{-,j}(x_j) = C_{-,j} * a_{-,j}(x_j)

        Parameters
        ----------
        values_x : np.ndarray of shape (n_observations, n_features)
            Values at which to evaluate the per-feature PD functions. Column j
            corresponds to the feature j values.
        data_x : np.ndarray of shape (n_samples, n_features)
            Background data used to estimate the empirical expectation.

        Returns
        -------
        list of length n_features
            Each entry is a tuple:
              (constants_per_epoch, pd_values)
            - constants_per_epoch: list of (C_plus, C_minus) per epoch
              Constants include OLS scaling (computed with effective_lambda = scaling * lambda)
            - pd_values: array of shape (n_observations, 2 * n_epochs) with columns
              [f+_epoch0, f-_epoch0, f+_epoch1, f-_epoch1, ...]
              Values have scaling absorbed into constants, so final prediction per epoch = f+ + f- (add columns)

        Raises
        ------
        ValueError
            If either array is not 2D or their column counts differ.
        """
        values_x = np.ascontiguousarray(values_x, dtype=np.float64)
        data_x = np.ascontiguousarray(data_x, dtype=np.float64)

        if values_x.ndim != 2:
            raise ValueError(f"values_x must be 2D array, got shape {values_x.shape}")
        if data_x.ndim != 2:
            raise ValueError(f"data_x must be 2D array, got shape {data_x.shape}")
        if values_x.shape[1] != data_x.shape[1]:
            raise ValueError(
                f"values_x and data_x must have the same number of columns, "
                f"got {values_x.shape[1]} and {data_x.shape[1]}"
            )

        return self._rust_instance.compute_first_order_partial_dependence_functions(
            values_x, data_x
        )

    def compute_ice_curves(
        self,
        observations: np.typing.NDArray[np.float64],
        feature_index: int,
        x_range: np.typing.NDArray[np.float64],
        data_x: np.typing.NDArray[np.float64],
    ) -> np.typing.NDArray[np.float64]:
        """Compute Individual Conditional Expectation (ICE) curves for a single feature.

        For each observation, varies the specified feature over the provided range while
        keeping all other features fixed at that observation's values. Computes f+ and f-
        separately for each epoch.

        Parameters
        ----------
        observations : np.ndarray of shape (n_obs, n_features)
            Observations to compute ICE curves for.
        feature_index : int
            Index of the feature to vary; must satisfy ``0 <= feature_index < n_features``.
        x_range : np.ndarray of shape (n_range_values,)
            Values to evaluate for the varying feature.
        data_x : np.ndarray of shape (n_samples, n_features)
            Training data (used for validation only).

        Returns
        -------
        np.ndarray of shape (n_obs, n_range_values, 2 * n_epochs)
            ICE curve values. Last dimension alternates: [f+_epoch0, f-_epoch0, f+_epoch1, f-_epoch1, ...]
            Values have scaling applied: scaling_plus * f_+ and scaling_minus * (-f_-)

        Raises
        ------
        ValueError
            If an array has the wrong number of dimensions, the column counts
            differ, or ``feature_index`` is out of range.

        Examples
        --------
        >>> # Compute ICE curves for feature 0
        >>> x_range = np.linspace(X[:, 0].min(), X[:, 0].max(), 50)
        >>> ice_values = model.compute_ice_curves(X[:10], feature_index=0, x_range=x_range, data_x=X)
        >>> # ice_values.shape = (10, 50, 2 * n_epochs)
        >>> # For epoch 0, observation 0:
        >>> f_plus_curve = ice_values[0, :, 0]  # shape (50,)
        >>> f_minus_curve = ice_values[0, :, 1]  # shape (50,)
        """
        # The Rust side expects contiguous float64 buffers.
        observations = np.ascontiguousarray(observations, dtype=np.float64)
        x_range = np.ascontiguousarray(x_range, dtype=np.float64)
        data_x = np.ascontiguousarray(data_x, dtype=np.float64)

        if observations.ndim != 2:
            raise ValueError(
                f"observations must be 2D array, got shape {observations.shape}"
            )
        if x_range.ndim != 1:
            raise ValueError(f"x_range must be 1D array, got shape {x_range.shape}")
        if data_x.ndim != 2:
            raise ValueError(f"data_x must be 2D array, got shape {data_x.shape}")
        if observations.shape[1] != data_x.shape[1]:
            raise ValueError(
                f"observations and data_x must have the same number of columns, "
                f"got {observations.shape[1]} and {data_x.shape[1]}"
            )
        n_features = observations.shape[1]
        # Reject out-of-range indices here with a clear message instead of
        # relying on the backend's conversion/bounds handling.
        if not 0 <= feature_index < n_features:
            raise ValueError(
                f"feature_index must be between 0 and {n_features - 1}, "
                f"got {feature_index}"
            )

        return self._rust_instance.compute_ice_curves(
            observations, feature_index, x_range, data_x
        )

    def save(self, path: str) -> None:
        """Save the MPF model to a binary file (preserves exact floating point values).

        **Note**: Binary format is same-version-only. Models saved with one version of mpf-py
        may not load with a different version due to schema changes. For portability across
        versions, consider exporting model parameters or predictions instead.

        Parameters
        ----------
        path : str
            Path to the file where the model will be saved.
        """
        self._rust_instance.save(path)

    @classmethod
    def load(cls, path: str) -> "MPF":
        """Load an MPF model from a binary file.

        **Note**: Binary format is same-version-only. Models saved with one version of mpf-py
        may not load with a different version due to schema changes.

        Parameters
        ----------
        path : str
            Path to the binary file containing the saved model.

        Returns
        -------
        MPF
            The loaded MPF model instance. If the per-epoch tree grid cache
            cannot be rebuilt, a ``UserWarning`` is emitted and the cache is
            left unset.
        """
        rust_instance = _MPF.load(path)
        instance = cls(rust_instance)
        try:
            instance._tree_grid_families_lst = cls._wrap_tree_grid_families(
                rust_instance
            )
        except (AttributeError, TypeError) as e:
            # Not critical for predictions, only for plotting helpers that use
            # the cache; degrade with a warning instead of failing the load.
            import warnings

            warnings.warn(
                f"Could not reconstruct _tree_grid_families_lst from loaded model: {e}. "
                "Plotting methods may not work, but predictions will still function correctly.",
                UserWarning,
            )
        return instance

    def plot_combined_tree_grids(
        self, individual_plots: bool = True, axis: Optional[int] = None
    ):
        """Plot the combined tree grid of every tree grid family (epoch).

        Each family's ``combined_tree_grid`` is wrapped in a :class:`TreeGrid`
        and drawn with :meth:`TreeGrid.plot_components`, forwarding
        ``individual_plots`` and ``axis``.
        """
        for tgf in self.tree_grid_families:
            combined_tg = TreeGrid(tgf.combined_tree_grid)
            combined_tg.plot_components(individual_plots=individual_plots, axis=axis)

    def plot_epoch_components(self, epoch: int) -> None:
        """Plot one-dimensional components for all tree grids at a given epoch.

        Creates one separate figure per component (feature). Each figure
        overlays the step functions of every tree grid from ``epoch``, with one
        colour per grid. Colours cycle through ``tab10`` when there are more
        than 10 grids.

        Parameters
        ----------
        epoch : int
            Zero-based epoch index to visualize. Must be within the range of
            available epochs.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        RuntimeError
            If the per-epoch tree grid cache was never initialized.
        ValueError
            If ``epoch`` is out of range or the epoch contains no tree grids.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError(
                "Matplotlib is required to use 'plot_epoch_components'. "
                "Please install it using: pip install matplotlib"
            ) from err

        if not hasattr(self, "_tree_grid_families_lst"):
            raise RuntimeError(
                "Internal tree grid cache not initialized. Ensure the model was "
                "created via MPF.fit() or MPF.load()."
            )

        total_epochs = len(self._tree_grid_families_lst)
        if epoch < 0 or epoch >= total_epochs:
            raise ValueError(f"epoch must be between 0 and {total_epochs - 1}")

        epoch_tree_grids = self._tree_grid_families_lst[epoch]
        if len(epoch_tree_grids) == 0:
            raise ValueError(f"No tree grids found for epoch {epoch}")

        # Infer the number of components (features) from the first grid.
        num_components = len(epoch_tree_grids[0].intervals)

        # Integer indexing into the qualitative tab10 colormap gives grid i the
        # i-th colour, wrapping after 10. Sampling it with linspace over more
        # than 10 points would instead hand adjacent grids the same colour.
        colors = plt.cm.tab10(np.arange(len(epoch_tree_grids)) % plt.cm.tab10.N)

        for component_index in range(num_components):
            _, ax = plt.subplots(1, 1, figsize=(10, 4))

            plotted_any = False
            for grid_index, grid in enumerate(epoch_tree_grids):
                intervals = grid.intervals[component_index]
                values = grid.mean_factor[component_index]

                # Build step-function points, skipping intervals with an
                # infinite bound (they cannot be drawn).
                x_points: list[float] = []
                y_points: list[float] = []
                for (x_start, x_end), y in zip(intervals, values):
                    if x_start == float("-inf") or x_end == float("inf"):
                        continue
                    x_points.extend([x_start, x_end])
                    y_points.extend([y, y])

                if x_points:
                    ax.step(
                        x_points,
                        y_points,
                        where="post",
                        lw=1,
                        color=colors[grid_index],
                        label=f"Grid {grid_index}",
                    )
                    plotted_any = True

            ax.set_xlabel("X")
            ax.set_ylabel("Value")
            ax.set_title(f"Epoch {epoch} — Component {component_index}")
            ax.grid(True)
            if plotted_any:
                # A legend with no labelled artists only produces a warning.
                ax.legend(loc="best", fontsize="small")

            plt.tight_layout()
            plt.show()
