use itertools::Itertools;
use ndarray::{ArrayView1, ArrayView2};
use rand::{Rng, SeedableRng};

use crate::{
    family::{
        aggregate_bagged::aggregate_bagged_two_tensor,
        combine_grids::combine_two_tensor_grids_geometric_mean,
        similarity::{compute_pairwise_similarity_backbone_and_tilt, find_medoid_index},
    },
    grid::{self},
    logging::{self},
    FitResult,
};

use super::super::grid::FittedTreeGrid;
use super::TreeGridFamily;
use super::{
    params::{CombinationStrategyParams, TreeGridFamilyParams},
    projection::project_grid_to_vectors,
    reconstruction::reconstruct_grid,
    reference_grid::ReferenceGrid,
    tensor_power::compute_consensus,
};

#[cfg(feature = "use-rayon")]
use rayon::prelude::*;

/// Draw a bootstrap resample of `n` observations and express it as per-row weights.
///
/// Performs `n` draws with replacement from `0..n`. The returned vector has length
/// `n`, and entry `i` is the number of times row `i` was drawn (as `f64`), so rows
/// that were never drawn get weight `0.0`.
///
/// **Invariant I34**: for bagged sampling the weights always sum to exactly `n`
/// (checked in debug builds).
fn bootstrap_sample_weights<R: Rng + ?Sized>(n: usize, rng: &mut R) -> Vec<f64> {
    // Tally how often each row index is drawn across n draws with replacement.
    let mut draws_per_row = vec![0usize; n];
    (0..n).for_each(|_| draws_per_row[rng.gen_range(0..n)] += 1);

    // Invariant I34: every draw lands on exactly one row.
    debug_assert_eq!(draws_per_row.iter().copied().sum::<usize>(), n);

    draws_per_row.iter().map(|&count| count as f64).collect()
}

/// Fit a single tree grid on `(x, y)` via the unified [`grid::fit`] entry point,
/// which works with or without logging.
///
/// NOTE(review): `bagged` is currently **ignored**. Every tree is fit on the full,
/// unweighted data set, and trees differ only through `rng`. An earlier version of
/// this doc claimed bootstrap sample-weights were applied, but nothing here computes
/// or passes them (e.g. via `bootstrap_sample_weights`).
/// TODO: re-enable bagging once a weighted variant of `grid::fit` is available.
fn fit_single_tree_grid<R: Rng + ?Sized>(
    x: ArrayView2<f64>,
    y: ArrayView1<f64>,
    tg_params: &grid::TreeGridParams,
    bagged: bool,
    rng: &mut R,
) -> (FitResult, FittedTreeGrid) {
    // Acknowledge the unused flag explicitly so the missing bagging support is
    // visible in the code, not just as an `unused_variables` warning.
    let _ = bagged;
    // `ArrayView` is `Copy`, so the views are passed straight through.
    grid::fit(x, y, tg_params, rng)
}

/// Fit one tree grid inside a `tree_id` logging scope.
///
/// A dedicated `StdRng` is seeded from `seed`, so each tree's fit depends only on
/// its own seed. Fitting goes through the unified `fit()` path. If an event channel
/// is set up, events flow through it; otherwise they are buffered in `LoggingState`
/// (when initialized). After fitting, the grid's lambdas and error are logged at
/// debug level, and the per-tree training error is reported via
/// `log_grid_error_fitted`.
fn fit_tree_grid_with_context(
    tree_id: usize,
    seed: u64,
    x: ArrayView2<f64>,
    y: ArrayView1<f64>,
    tg_params: &grid::TreeGridParams,
    bagged: bool,
) -> (FitResult, FittedTreeGrid) {
    use crate::logging::{log_grid_error_fitted, GridErrorVariant};

    logging::with_tree_id(tree_id, || {
        let mut seeded_rng = rand::rngs::StdRng::seed_from_u64(seed);
        let (result, tree_grid) = fit_single_tree_grid(x, y, tg_params, bagged, &mut seeded_rng);

        log::debug!(
            "Fitted tree grid, lambda+: {:?}, lambda-: {:?}, error: {:?}",
            tree_grid.lambda_plus,
            tree_grid.lambda_minus,
            result.err
        );
        // Feature flags for this logging call are handled inside the function.
        log_grid_error_fitted(result.err, GridErrorVariant::Train);

        (result, tree_grid)
    })
}

pub fn fit_ensemble<R: Rng + ?Sized>(
    x: ArrayView2<f64>,
    y: ArrayView1<f64>,
    hyperparameters: &TreeGridFamilyParams,
    rng: &mut R,
) -> (FitResult, TreeGridFamily) {
    let TreeGridFamilyParams {
        n_trees,
        bagged,
        tg_params,
        combination_strategy,
        similarity_threshold,
        aggregation_method,
    } = hyperparameters;

    // Capture epoch context from thread-local storage before spawning parallel workers
    // This is needed because thread-local storage doesn't propagate to Rayon worker threads
    let epoch_opt = logging::current_epoch();

    // Pre-generate seeds for each thread
    let seeds: Vec<u64> = (0..*n_trees).map(|_| rng.gen()).collect();

    // Fit all tree grids (sequential or parallel based on feature flags)
    // Important: When using rayon, we need to preserve tree_id order by creating (tree_id, seed) pairs first
    // This ensures tree_grids[i] corresponds to tree_id=i even when parallel execution completes out of order
    let tree_grids: Vec<FittedTreeGrid>;

    #[cfg(not(feature = "use-rayon"))]
    {
        let (_, grids): (Vec<FitResult>, Vec<FittedTreeGrid>) = seeds
            .iter()
            .enumerate()
            .map(|(tree_id, &seed)| {
                fit_tree_grid_with_context(tree_id, seed, x, y, tg_params, *bagged)
            })
            .unzip();
        tree_grids = grids;
    }

    #[cfg(feature = "use-rayon")]
    {
        // Create (tree_id, seed) pairs to preserve order when collecting parallel results
        // Rayon's collect() does NOT guarantee order preservation, so we need to sort by tree_id
        // after collecting to ensure tree_grids[i] corresponds to tree_id=i
        let tree_seed_pairs: Vec<(usize, u64)> = (0..*n_trees).zip(seeds).collect();
        let mut results: Vec<(usize, FitResult, FittedTreeGrid)> = tree_seed_pairs
            .into_par_iter()
            .map(|(tree_id, seed)| {
                // Set epoch context in each worker thread since thread-local storage doesn't propagate
                // If epoch was set on the main thread, set it here too
                if let Some(epoch) = epoch_opt {
                    logging::with_epoch(epoch, || {
                        let (fit_res, grid) =
                            fit_tree_grid_with_context(tree_id, seed, x, y, tg_params, *bagged);
                        (tree_id, fit_res, grid)
                    })
                } else {
                    // No epoch context - this shouldn't happen in normal flow, but handle gracefully
                    let (fit_res, grid) =
                        fit_tree_grid_with_context(tree_id, seed, x, y, tg_params, *bagged);
                    (tree_id, fit_res, grid)
                }
            })
            .collect();
        // Sort by tree_id to ensure correct order for combination choice logging
        results.sort_by_key(|(tree_id, _, _)| *tree_id);
        tree_grids = results.into_iter().map(|(_, _, grid)| grid).collect();
    }

    // Choose combination method based on strategy
    let (candidate_indices, primary_tree_grid) = match combination_strategy {
        CombinationStrategyParams::TensorPower {
            n_bins,
            n_iterations,
        } => {
            // Use Tensor Power Method
            log::info!(
                "Using Tensor Power Method: n_bins={}, n_iterations={}",
                n_bins,
                n_iterations
            );

            // Step 1: Create ReferenceGrid from training data
            let ref_grid = ReferenceGrid::from_data(x.view(), *n_bins);

            // Step 2: Project each bag to vectors
            let mut bagged_components = Vec::new();
            let mut bag_lambdas = Vec::new();

            for grid in &tree_grids {
                let (projected_vectors, new_lambda) = project_grid_to_vectors(grid, &ref_grid);
                bagged_components.push(projected_vectors);
                bag_lambdas.push(new_lambda);
            }

            // Step 3: Compute consensus
            let (consensus_components, consensus_scale) =
                compute_consensus(&bagged_components, &bag_lambdas, *n_iterations);

            // Step 4: Reconstruct grid
            let combined_grid =
                reconstruct_grid(&ref_grid, consensus_components, consensus_scale, x.view());

            // All grids are candidates for tensor power method
            let candidate_indices: Vec<usize> = (0..tree_grids.len()).collect();
            (candidate_indices, combined_grid)
        }
        CombinationStrategyParams::GeometricMeanTwoTensor => {
            // Use two-tensor geometric mean combination
            log::info!("Using two-tensor geometric mean combination");

            // Optimization: if similarity_threshold is 0, combine all grids without computing pairwise distances
            if *similarity_threshold == 0.0 {
                log::info!(
                    "Similarity threshold is 0.0, combining all {} grids without distance computation",
                    tree_grids.len()
                );
                let candidate_indices: Vec<(usize, f64)> =
                    (0..tree_grids.len()).map(|i| (i, 0.0)).collect();
                crate::logging::log_combination_choice(
                    "CosineSimilarity",
                    None,
                    &candidate_indices,
                );

                // Combine all grids with equal weights
                let weights = vec![1.0; tree_grids.len()];
                let combined_grid =
                    combine_two_tensor_grids_geometric_mean(&tree_grids, Some(&weights), x.view());
                ((0..tree_grids.len()).collect(), combined_grid)
            } else {
                // // First align signs to determine candidate grids (returns scored candidates)
                // let (best_index, scored_candidates) =
                //     find_best_reference_index_and_top_k_candidates(&tree_grids, *similarity_threshold);
                // // Log combination choice (method, best index and scored candidates)
                // crate::logging::log_combination_choice(
                //     "GeometricMeanTwoTensor",
                //     Some(best_index),
                //     &scored_candidates,
                // );
                // let candidate_indices = scored_candidates.iter().map(|(i, _)| *i).collect();
                // // Create weights: 0 for non-candidates, 1.0 for candidates
                // // (normalization happens automatically in combination functions via total_weight)
                // let mut weights = vec![0.0; tree_grids.len()];
                // for &candidate_idx in &candidate_indices {
                //     weights[candidate_idx] = 1.0;
                // }

                let (best_index, pairwise_lambda_distances, sum_distances) =
                    find_medoid_index(&tree_grids);
                let reference_grid = &tree_grids[best_index];
                let pairwise_similarity_backbone_and_tilt =
                    compute_pairwise_similarity_backbone_and_tilt(
                        reference_grid,
                        &tree_grids,
                        x.view(),
                    );
                let distances_to_medoid = pairwise_lambda_distances
                    .into_iter()
                    .nth(best_index)
                    .unwrap();
                let total_distance_to_medoid = sum_distances[best_index];

                let combined_similarities = pairwise_similarity_backbone_and_tilt
                    .iter()
                    .enumerate()
                    .map(|(i, (similarity_backbone, similarity_tilt))| {
                        let cond_threshold = (total_distance_to_medoid - distances_to_medoid[i])
                            / total_distance_to_medoid;
                        cond_threshold * (similarity_backbone + 1.0) * (similarity_tilt + 1.0) / 4.0
                    })
                    .collect::<Vec<f64>>();

                let top_k = ((1.0 - *similarity_threshold) * tree_grids.len() as f64) as usize;
                let candidate_indices: Vec<(usize, f64)> = combined_similarities
                    .iter()
                    .enumerate()
                    .sorted_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .take(top_k)
                    .map(|(i, similarity)| (i, *similarity))
                    .collect::<Vec<(usize, f64)>>();

                log::info!("Candidate indices: {:?}", candidate_indices);
                crate::logging::log_combination_choice(
                    "CosineSimilarity",
                    Some(best_index),
                    &candidate_indices,
                );

                let mut weights = vec![0.0; tree_grids.len()];
                for (i, _) in &candidate_indices {
                    weights[*i] = 1.0;
                }
                let combined_grid =
                    combine_two_tensor_grids_geometric_mean(&tree_grids, Some(&weights), x.view());

                // let combined_grid: FittedTreeGrid = combine_median_two_tensor_grids_geometric_mean(
                //     &tree_grids,
                //     *similarity_threshold,
                //     x.view(),
                // );
                ((0..tree_grids.len()).collect(), combined_grid)
            }
        }
        CombinationStrategyParams::BaggedTwoTensor => {
            // Use bagged two-tensor aggregation with component-shape distance
            // Convert similarity_threshold to trim_percentage: keep (1.0 - similarity_threshold) fraction
            let trim_percentage = 1.0 - *similarity_threshold;
            log::info!(
                "Using bagged two-tensor aggregation with similarity_threshold={}, trim_percentage={}",
                similarity_threshold,
                trim_percentage
            );

            let combined_grid = aggregate_bagged_two_tensor(
                &tree_grids,
                x.view(),
                None, // TODO: Support weighted bin weights if needed
                trim_percentage,
            );

            // All grids are candidates for bagged aggregation
            let candidate_indices: Vec<usize> = (0..tree_grids.len()).collect();
            (candidate_indices, combined_grid)
        }
    };

    let tgf = TreeGridFamily::new_ensemble(
        tree_grids,
        primary_tree_grid,
        candidate_indices,
        aggregation_method.clone(),
    );

    let preds = tgf.predict(x);
    log::debug!("Combined preds mean: {:?}", preds.mean());

    let residuals = &y - &preds;
    let err = residuals.pow2().mean().unwrap();
    log::info!(
        "Combined tree grid error: {:?}, lambda+: {:?}, lambda-: {:?}",
        err,
        tgf.primary_tree_grid.lambda_plus,
        tgf.primary_tree_grid.lambda_minus
    );
    // Log family-level combined grid error
    // Feature flag is encapsulated inside the logging function
    use crate::logging::{log_grid_error_combined, GridErrorVariant};
    log_grid_error_combined(err, GridErrorVariant::Train);
    (
        FitResult {
            err,
            residuals: residuals.to_owned(),
            y_hat: preds,
        },
        tgf,
    )
}
