use ndarray::{ArrayView1, ArrayView2};

use crate::grid::FittedTreeGrid;

/// Select a reference grid index from the grids' `(lambda_plus, lambda_minus)` pairs.
///
/// NOTE(review): despite the function name, the returned index is NOT the grid that minimises
/// the sum of distances to all other grids. It is the grid whose `(lambda_plus, lambda_minus)`
/// point is closest to the origin, i.e. the smallest `lambda_plus² + lambda_minus²`. Ties
/// resolve to the first such grid. The per-grid sums are still computed and returned, so a
/// caller that wants the true medoid can take the argmin of the third tuple element. Confirm
/// which rule is intended before renaming or changing this.
///
/// # Returns
/// `(selected_index, pairwise, sums)` where:
/// - `pairwise[i][j]` is the *squared* Euclidean distance between the lambda pairs of grids
///   `i` and `j` (no square root is taken);
/// - `sums[i]` is the sum of row `pairwise[i]`.
///
/// # Panics
/// Panics if `grids` is empty.
pub fn find_medoid_index(grids: &[FittedTreeGrid]) -> (usize, Vec<Vec<f64>>, Vec<f64>) {
    assert!(
        !grids.is_empty(),
        "find_medoid_index requires at least one grid"
    );

    let lambdas: Vec<(f64, f64)> = grids
        .iter()
        .map(|grid| (grid.lambda_plus, grid.lambda_minus))
        .collect();

    // Squared Euclidean distance in (lambda_plus, lambda_minus) space.
    let squared_distance =
        |p: &(f64, f64), q: &(f64, f64)| (q.0 - p.0).powi(2) + (q.1 - p.1).powi(2);

    let pairwise_lambda_distances: Vec<Vec<f64>> = lambdas
        .iter()
        .map(|p| lambdas.iter().map(|q| squared_distance(p, q)).collect())
        .collect();

    let sum_distances: Vec<f64> = pairwise_lambda_distances
        .iter()
        .map(|row| row.iter().sum())
        .collect();

    // Selection rule actually used: smallest squared norm of the lambda pair.
    // `min_by` returns the first of several equal minima.
    let best_index = lambdas
        .iter()
        .enumerate()
        .min_by(|(_, p), (_, q)| (p.0 * p.0 + p.1 * p.1).total_cmp(&(q.0 * q.0 + q.1 * q.1)))
        .map(|(index, _)| index)
        .expect("grids is non-empty (checked above)");

    (best_index, pairwise_lambda_distances, sum_distances)
}

/// Compare every grid in `grids` against `reference_grid` on the points in `x`.
///
/// The reference grid and each grid are evaluated with `predict_backbone_and_tilt(x)`. Their
/// backbone predictions and their tilt predictions are then compared separately with
/// `cosine_similarity`.
///
/// # Returns
/// One `(backbone_similarity, tilt_similarity)` tuple per entry of `grids`, in the same order.
pub fn compute_pairwise_similarity_backbone_and_tilt(
    reference_grid: &FittedTreeGrid,
    grids: &[FittedTreeGrid],
    x: ArrayView2<f64>,
) -> Vec<(f64, f64)> {
    // The reference predictions are computed once and reused for every comparison.
    let (ref_backbone, ref_tilt) = reference_grid.predict_backbone_and_tilt(x);

    let mut similarities = Vec::with_capacity(grids.len());
    for grid in grids {
        let (backbone, tilt) = grid.predict_backbone_and_tilt(x);
        let backbone_similarity = cosine_similarity(ref_backbone.view(), backbone.view());
        let tilt_similarity = cosine_similarity(ref_tilt.view(), tilt.view());
        similarities.push((backbone_similarity, tilt_similarity));
    }
    similarities
}

/// Combine backbone and tilt cosine similarities into a single score.
///
/// Each cosine similarity is shifted by `+1.0`, which maps `[-1, 1]` onto `[0, 2]`. The two
/// shifted values are multiplied and the product is halved. When both similarities lie in
/// `[-1, 1]`, the score therefore lies in `[0, 2]`.
///
/// NOTE(review): dividing by 2 rather than 4 means the score is not normalised to `[0, 1]`.
/// Confirm this scale is intended before relying on it.
#[allow(dead_code)] // May be used in future
pub fn similarity_backbone_and_tilt(
    backbone_a: ArrayView1<f64>,
    tilt_a: ArrayView1<f64>,
    backbone_b: ArrayView1<f64>,
    tilt_b: ArrayView1<f64>,
) -> f64 {
    let backbone_shifted = cosine_similarity(backbone_a, backbone_b) + 1.0;
    let tilt_shifted = cosine_similarity(tilt_a, tilt_b) + 1.0;
    backbone_shifted * tilt_shifted / 2.0
}

/// Compute the cosine similarity `a·b / (‖a‖ ‖b‖)` between two vectors.
///
/// Returns `0.0` when either vector has zero norm, where the similarity is undefined. Otherwise
/// the result is clamped to `[-1.0, 1.0]` to absorb floating-point rounding. NaN inputs
/// propagate as NaN.
///
/// # Panics
/// Panics if `a` and `b` have different lengths (propagated from `dot`).
fn cosine_similarity(a: ArrayView1<f64>, b: ArrayView1<f64>) -> f64 {
    let dot_product = a.dot(&b);
    // `x.dot(&x)` is the squared norm, computed without allocating a temporary array.
    // Each square root is taken separately, rather than `sqrt(|a|² · |b|²)`. This avoids
    // overflowing to `inf` or underflowing to `0` when the two squared norms are multiplied.
    let denominator = a.dot(&a).sqrt() * b.dot(&b).sqrt();
    if denominator == 0.0 {
        0.0
    } else {
        (dot_product / denominator).clamp(-1.0, 1.0)
    }
}
