use serde::Serialize;

use crate::grid::{refinement::RefinementStrategy, splitting::SplitStrategy};

/// Selects which [`RefinementStrategyParams`] variant
/// [`RefinementStrategyParamsBuilder::build`] produces.
///
/// The enum is fieldless, so it is cheap to copy and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RefinementStrategyType {
    /// Build `RefinementStrategyParams::L2`.
    L2,
    /// Build `RefinementStrategyParams::Huber`.
    Huber,
}

/// Builder for [`RefinementStrategyParams`].
///
/// `strategy_type` only selects the variant; every numeric field is copied unchanged
/// into whichever variant `build` produces.
#[derive(Debug, Clone)]
pub struct RefinementStrategyParamsBuilder {
    /// Variant selector consulted by `build`. `new` starts with `L2`.
    strategy_type: RefinementStrategyType,
    /// Forwarded as the `alpha` field of the built params. `new` starts with 0.0.
    /// NOTE(review): presumably a regularization strength — confirm against the solver.
    alpha: f64,
    /// Two-tensor L2 coupling between u_+ and u_- (objective τ).
    tilt_tau: f64,
    /// Two-tensor L1 coupling on (u_+ - u_-) (objective ρ).
    tilt_rho: f64,
    /// Prior sample size for parent anchoring (tau_0).
    /// Interpreted as "how many samples worth of confidence in the parent".
    /// Default: 0.0 (no anchoring). Typical values: 10-50.
    prior_sample_size: f64,
    /// Forwarded as the `update_clamp` field of the built params.
    /// Default: `f64::INFINITY` (no clamping).
    update_clamp: f64,
}

impl RefinementStrategyParamsBuilder {
    /// Creates a builder with the defaults: L2 strategy, `alpha = 0.0`, the two-tensor
    /// solver's default τ and ρ, no parent anchoring and no update clamping.
    pub fn new() -> Self {
        let tilt_tau = crate::grid::two_tensor_solver::DEFAULT_TAU;
        let tilt_rho = crate::grid::two_tensor_solver::DEFAULT_RHO;

        Self {
            strategy_type: RefinementStrategyType::L2,
            alpha: 0.0,
            tilt_tau,
            tilt_rho,
            // 0.0 disables parent anchoring.
            prior_sample_size: 0.0,
            // An infinite clamp means updates are never clamped.
            update_clamp: f64::INFINITY,
        }
    }

    /// Selects the L2 variant for `build`.
    pub fn l2(self) -> Self {
        Self {
            strategy_type: RefinementStrategyType::L2,
            ..self
        }
    }

    /// Selects the Huber variant for `build`.
    pub fn huber(self) -> Self {
        Self {
            strategy_type: RefinementStrategyType::Huber,
            ..self
        }
    }

    /// Sets `alpha`, which is forwarded unchanged to the built params.
    pub fn alpha(self, alpha: f64) -> Self {
        Self { alpha, ..self }
    }

    /// Sets the two-tensor L2 coupling between u_+ and u_- (objective τ).
    pub fn tilt_tau(self, tilt_tau: f64) -> Self {
        Self { tilt_tau, ..self }
    }

    /// Sets the two-tensor L1 coupling on (u_+ - u_-) (objective ρ).
    pub fn tilt_rho(self, tilt_rho: f64) -> Self {
        Self { tilt_rho, ..self }
    }

    /// Sets the prior sample size used for parent anchoring (tau_0).
    ///
    /// Read it as "how many samples worth of confidence we have that children should
    /// equal their parent". With tau_0 = 30, a child with 10 samples is shrunk heavily
    /// toward the parent, whereas a child with 100 samples mostly trusts its own data.
    ///
    /// Default: 0.0 (no anchoring). Typical values: 10-50.
    pub fn prior_sample_size(self, tau_0: f64) -> Self {
        Self {
            prior_sample_size: tau_0,
            ..self
        }
    }

    /// Sets `update_clamp`, which is forwarded unchanged to the built params.
    /// The default of `f64::INFINITY` means no clamping.
    pub fn update_clamp(self, update_clamp: f64) -> Self {
        Self {
            update_clamp,
            ..self
        }
    }

    /// Consumes the builder and produces the variant chosen via `l2` / `huber`.
    ///
    /// Both variants receive exactly the same numeric settings.
    pub fn build(self) -> RefinementStrategyParams {
        let Self {
            strategy_type,
            alpha,
            tilt_tau,
            tilt_rho,
            prior_sample_size,
            update_clamp,
        } = self;

        match strategy_type {
            RefinementStrategyType::L2 => RefinementStrategyParams::L2 {
                alpha,
                tilt_tau,
                tilt_rho,
                prior_sample_size,
                update_clamp,
            },
            RefinementStrategyType::Huber => RefinementStrategyParams::Huber {
                alpha,
                tilt_tau,
                tilt_rho,
                prior_sample_size,
                update_clamp,
            },
        }
    }
}

impl Default for RefinementStrategyParamsBuilder {
    /// Equivalent to [`RefinementStrategyParamsBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Selects which [`SplitStrategyParams`] variant
/// [`SplitStrategyParamsBuilder::build`] produces.
///
/// The enum is fieldless, so it is cheap to copy and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SplitStrategyType {
    /// Build `SplitStrategyParams::RandomSplit`.
    RandomSplit,
    /// Build `SplitStrategyParams::BestSplit`.
    BestSplit,
    /// Build `SplitStrategyParams::TopKSplits`.
    TopKSplits,
}

/// Builder for [`SplitStrategyParams`].
///
/// The builder stores settings for every strategy; `build` forwards only the ones the
/// selected variant actually has and ignores the rest.
#[derive(Debug, Clone)]
pub struct SplitStrategyParamsBuilder {
    /// Variant selector consulted by `build`. `new` starts with `RandomSplit`.
    strategy_type: SplitStrategyType,
    /// Forwarded only to `RandomSplit`. `new` starts with 10.
    /// NOTE(review): presumably the number of candidate splits tried — confirm.
    split_try: usize,
    /// Forwarded only to `RandomSplit`. `new` starts with 1.0.
    /// NOTE(review): presumably the fraction of columns sampled — confirm.
    colsample_bytree: f64,
    /// Forwarded to every variant. `new` starts with 1.
    min_interval_samples: usize,
    /// Forwarded only to `TopKSplits`. `new` starts with 5.
    top_k: usize,
    /// Forwarded only to `TopKSplits`. `new` starts with `false`.
    must_fill_all_k: bool,
    /// Forwarded to every variant. `new` starts with 0.0.
    min_split_loss: f64,
    /// Complexity penalty (lambda) for adaptive merge bonus.
    /// Default: 0.0 (no complexity penalty). Typical values: 0.5-2.0.
    complexity_penalty: f64,
}

impl SplitStrategyParamsBuilder {
    /// Creates a builder with the defaults: random-split strategy, `split_try = 10`,
    /// `colsample_bytree = 1.0`, `min_interval_samples = 1`, `top_k = 5`,
    /// `must_fill_all_k = false`, and no split-loss threshold or complexity penalty.
    pub fn new() -> Self {
        Self {
            strategy_type: SplitStrategyType::RandomSplit,
            split_try: 10,
            colsample_bytree: 1.0,
            min_interval_samples: 1,
            top_k: 5,
            must_fill_all_k: false,
            min_split_loss: 0.0,
            // 0.0 disables the complexity penalty.
            complexity_penalty: 0.0,
        }
    }

    /// Selects the `RandomSplit` variant for `build`.
    pub fn random_split(self) -> Self {
        Self {
            strategy_type: SplitStrategyType::RandomSplit,
            ..self
        }
    }

    /// Selects the `BestSplit` variant for `build`.
    pub fn best_split(self) -> Self {
        Self {
            strategy_type: SplitStrategyType::BestSplit,
            ..self
        }
    }

    /// Selects the `TopKSplits` variant for `build`.
    pub fn top_k_splits(self) -> Self {
        Self {
            strategy_type: SplitStrategyType::TopKSplits,
            ..self
        }
    }

    /// Sets `split_try`; only the `RandomSplit` variant carries it.
    pub fn split_try(self, split_try: usize) -> Self {
        Self { split_try, ..self }
    }

    /// Sets `colsample_bytree`; only the `RandomSplit` variant carries it.
    pub fn colsample_bytree(self, colsample_bytree: f64) -> Self {
        Self {
            colsample_bytree,
            ..self
        }
    }

    /// Sets `min_interval_samples`, which every variant carries.
    pub fn min_interval_samples(self, min_interval_samples: usize) -> Self {
        Self {
            min_interval_samples,
            ..self
        }
    }

    /// Sets `min_split_loss`, which every variant carries.
    pub fn min_split_loss(self, min_split_loss: f64) -> Self {
        Self {
            min_split_loss,
            ..self
        }
    }

    /// Sets the complexity penalty (lambda) applied through the adaptive merge bonus.
    ///
    /// The merge bonus is computed as:
    ///   bonus = lambda * MSE * (log(n)/n + 1/harmonic_mean(n_left, n_right))
    ///
    /// The form is BIC-inspired and scale-invariant; a larger lambda favours simpler
    /// models. Default: 0.0 (no complexity penalty). Typical values: 0.5-2.0.
    pub fn complexity_penalty(self, lambda: f64) -> Self {
        Self {
            complexity_penalty: lambda,
            ..self
        }
    }

    /// Sets `top_k`; only the `TopKSplits` variant carries it.
    pub fn top_k(self, top_k: usize) -> Self {
        Self { top_k, ..self }
    }

    /// Sets `must_fill_all_k`; only the `TopKSplits` variant carries it.
    pub fn must_fill_all_k(self, must_fill_all_k: bool) -> Self {
        Self {
            must_fill_all_k,
            ..self
        }
    }

    /// Consumes the builder and produces the selected variant.
    ///
    /// Settings that the selected variant has no field for are dropped silently.
    pub fn build(self) -> SplitStrategyParams {
        let Self {
            strategy_type,
            split_try,
            colsample_bytree,
            min_interval_samples,
            top_k,
            must_fill_all_k,
            min_split_loss,
            complexity_penalty,
        } = self;

        match strategy_type {
            SplitStrategyType::RandomSplit => SplitStrategyParams::RandomSplit {
                split_try,
                colsample_bytree,
                min_interval_samples,
                min_split_loss,
                complexity_penalty,
            },
            SplitStrategyType::BestSplit => SplitStrategyParams::BestSplit {
                min_interval_samples,
                min_split_loss,
                complexity_penalty,
            },
            SplitStrategyType::TopKSplits => SplitStrategyParams::TopKSplits {
                top_k,
                must_fill_all_k,
                min_interval_samples,
                min_split_loss,
                complexity_penalty,
            },
        }
    }
}

impl Default for SplitStrategyParamsBuilder {
    /// Equivalent to [`SplitStrategyParamsBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}

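// Non-exact random splitting is deprecated; only the exact splitter is supported.
// NOTE(review): `RandomSplit` is still defined below and is the builders' default —
// confirm what this deprecation is meant to cover.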

/// Split-search configuration, with one variant per supported strategy.
///
/// `get_split_strategy` converts a value into the runtime `SplitStrategy`. The type
/// derives `PartialEq`, matching `RefinementStrategyParams`, so configurations can be
/// compared directly (for example in tests).
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum SplitStrategyParams {
    // Only these three strategies are kept.
    /// Parameters for the random split strategy.
    RandomSplit {
        /// NOTE(review): presumably the number of candidate splits tried — confirm.
        split_try: usize,
        /// NOTE(review): presumably the fraction of columns sampled — confirm.
        colsample_bytree: f64,
        /// Shared by all variants; the builder defaults it to 1.
        min_interval_samples: usize,
        /// Shared by all variants; the builder defaults it to 0.0.
        min_split_loss: f64,
        /// Complexity penalty (lambda) for adaptive merge bonus.
        complexity_penalty: f64,
    },
    /// Parameters for the best-split strategy.
    BestSplit {
        /// Shared by all variants; the builder defaults it to 1.
        min_interval_samples: usize,
        /// Shared by all variants; the builder defaults it to 0.0.
        min_split_loss: f64,
        /// Complexity penalty (lambda) for adaptive merge bonus.
        complexity_penalty: f64,
    },
    /// Parameters for the top-k splits strategy.
    TopKSplits {
        /// The builder defaults it to 5.
        top_k: usize,
        /// The builder defaults it to `false`.
        must_fill_all_k: bool,
        /// Shared by all variants; the builder defaults it to 1.
        min_interval_samples: usize,
        /// Shared by all variants; the builder defaults it to 0.0.
        min_split_loss: f64,
        /// Complexity penalty (lambda) for adaptive merge bonus.
        complexity_penalty: f64,
    },
}

impl SplitStrategyParams {
    /// Maps these parameters onto the matching [`SplitStrategy`] variant
    /// (`RandomSplit` → `Random`, `BestSplit` → `Best`, `TopKSplits` → `TopK`),
    /// copying every field across unchanged.
    ///
    /// NOTE(review): an older comment here said non-exact splitting was removed and
    /// pointed callers at `get_split_strategy_exact`. No such method is visible in this
    /// file — confirm whether it still exists before relying on that advice.
    pub fn get_split_strategy(&self) -> SplitStrategy {
        // All fields are `Copy`, so matching on `*self` binds them by value.
        match *self {
            Self::RandomSplit {
                split_try,
                colsample_bytree,
                min_interval_samples,
                min_split_loss,
                complexity_penalty,
            } => SplitStrategy::Random {
                split_try,
                colsample_bytree,
                min_interval_samples,
                complexity_penalty,
                min_split_loss,
            },
            Self::BestSplit {
                min_interval_samples,
                min_split_loss,
                complexity_penalty,
            } => SplitStrategy::Best {
                min_interval_samples,
                complexity_penalty,
                min_split_loss,
            },
            Self::TopKSplits {
                top_k,
                must_fill_all_k,
                min_interval_samples,
                min_split_loss,
                complexity_penalty,
            } => SplitStrategy::TopK {
                top_k,
                must_fill_all_k,
                min_interval_samples,
                complexity_penalty,
                min_split_loss,
            },
        }
    }
}

/// Refinement configuration, with one variant per supported loss.
///
/// Both variants carry the same five settings; `get_refinement_strategy` converts a
/// value into the runtime `RefinementStrategy`.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum RefinementStrategyParams {
    /// Parameters converted to `RefinementStrategy::L2Refinement`.
    L2 {
        /// NOTE(review): presumably a regularization strength — confirm against the solver.
        alpha: f64,
        /// Two-tensor L2 coupling between u_+ and u_- (objective τ).
        tilt_tau: f64,
        /// Two-tensor L1 coupling on (u_+ - u_-) (objective ρ).
        tilt_rho: f64,
        /// Prior sample size for parent anchoring (tau_0).
        /// Interpreted as "how many samples worth of confidence in the parent".
        prior_sample_size: f64,
        /// The builder defaults this to `f64::INFINITY` (no clamping).
        update_clamp: f64,
    },
    /// Parameters converted to `RefinementStrategy::HuberRefinement`; that conversion
    /// also supplies a fixed `c` of 1.345, which is not configurable here.
    Huber {
        /// NOTE(review): presumably a regularization strength — confirm against the solver.
        alpha: f64,
        /// Two-tensor L2 coupling between u_+ and u_- (objective τ).
        tilt_tau: f64,
        /// Two-tensor L1 coupling on (u_+ - u_-) (objective ρ).
        tilt_rho: f64,
        /// Prior sample size for parent anchoring (tau_0).
        prior_sample_size: f64,
        /// The builder defaults this to `f64::INFINITY` (no clamping).
        update_clamp: f64,
    },
}

impl RefinementStrategyParams {
    /// Maps these parameters onto the matching [`RefinementStrategy`] variant
    /// (`L2` → `L2Refinement`, `Huber` → `HuberRefinement`), copying every field
    /// across unchanged. The Huber conversion additionally fixes `c` at 1.345.
    pub fn get_refinement_strategy(&self) -> RefinementStrategy {
        // Fixed `c` passed to `HuberRefinement`.
        // NOTE(review): 1.345 looks like the conventional Huber tuning constant — confirm.
        const HUBER_C: f64 = 1.345;

        // All fields are `Copy`, so matching on `*self` binds them by value.
        match *self {
            Self::L2 {
                alpha,
                tilt_tau,
                tilt_rho,
                prior_sample_size,
                update_clamp,
            } => RefinementStrategy::L2Refinement {
                alpha,
                tilt_tau,
                tilt_rho,
                prior_sample_size,
                update_clamp,
            },
            Self::Huber {
                alpha,
                tilt_tau,
                tilt_rho,
                prior_sample_size,
                update_clamp,
            } => RefinementStrategy::HuberRefinement {
                alpha,
                c: HUBER_C,
                tilt_tau,
                tilt_rho,
                prior_sample_size,
                update_clamp,
            },
        }
    }
}

/// Top-level, serializable configuration bundling the iteration count with the split
/// and refinement strategy parameters.
///
/// `Default` yields the same value as `TreeGridParamsBuilder::new().build()`.
#[derive(Debug, Clone, Serialize)]
pub struct TreeGridParams {
    /// Iteration count; the builder defaults it to 25.
    /// NOTE(review): the consumer of this value is not visible in this file.
    pub n_iter: usize,
    /// How splits are searched for.
    pub split_strategy_params: SplitStrategyParams,
    /// How values are refined.
    pub refinement_strategy_params: RefinementStrategyParams,
}

/// Builder for [`TreeGridParams`].
///
/// Holds the same three fields as the built value; each setter replaces one of them.
#[derive(Debug, Clone)]
pub struct TreeGridParamsBuilder {
    /// Iteration count. `new` starts with 25.
    n_iter: usize,
    /// Split configuration. `new` starts with `RandomSplit`.
    split_strategy_params: SplitStrategyParams,
    /// Refinement configuration. `new` starts with `L2`.
    refinement_strategy_params: RefinementStrategyParams,
}

impl TreeGridParamsBuilder {
    /// Creates a builder with 25 iterations, a `RandomSplit` split strategy and an
    /// `L2` refinement strategy.
    ///
    /// The strategy literals below repeat the defaults of `SplitStrategyParamsBuilder::new`
    /// and `RefinementStrategyParamsBuilder::new`; keep them in sync when changing either.
    pub fn new() -> Self {
        let split_strategy_params = SplitStrategyParams::RandomSplit {
            split_try: 10,
            colsample_bytree: 1.0,
            min_interval_samples: 1,
            min_split_loss: 0.0,
            complexity_penalty: 0.0,
        };

        let refinement_strategy_params = RefinementStrategyParams::L2 {
            alpha: 0.0,
            tilt_tau: crate::grid::two_tensor_solver::DEFAULT_TAU,
            tilt_rho: crate::grid::two_tensor_solver::DEFAULT_RHO,
            prior_sample_size: 0.0,
            update_clamp: f64::INFINITY,
        };

        Self {
            n_iter: 25,
            split_strategy_params,
            refinement_strategy_params,
        }
    }

    /// Sets the iteration count.
    pub fn n_iter(self, n_iter: usize) -> Self {
        Self { n_iter, ..self }
    }

    /// Replaces the split strategy parameters.
    pub fn split_strategy(self, strategy: SplitStrategyParams) -> Self {
        Self {
            split_strategy_params: strategy,
            ..self
        }
    }

    /// Replaces the refinement strategy parameters.
    pub fn refinement_strategy(self, strategy: RefinementStrategyParams) -> Self {
        Self {
            refinement_strategy_params: strategy,
            ..self
        }
    }

    /// Consumes the builder and moves its three fields into a `TreeGridParams`.
    pub fn build(self) -> TreeGridParams {
        let Self {
            n_iter,
            split_strategy_params,
            refinement_strategy_params,
        } = self;

        TreeGridParams {
            n_iter,
            split_strategy_params,
            refinement_strategy_params,
        }
    }
}

impl Default for TreeGridParamsBuilder {
    /// Equivalent to [`TreeGridParamsBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Default for TreeGridParams {
    /// Returns the parameters produced by an untouched [`TreeGridParamsBuilder`].
    fn default() -> Self {
        TreeGridParamsBuilder::new().build()
    }
}
