import math
import time
from typing import Sequence, Optional, Dict, Any, Tuple

import numpy as np
import pandas as pd

from vanilla_clustering import k_median_clustering, k_means_clustering
from mst_tree_embedding import _compute_distance_matrix_two_sets, build_mst_plus_assignment_tree
from tree_to_binary_tree_embedding import embed_tree_to_binary_tree
from binary_tree import binary_tree_pairwise_distances

from data import generate_random_facility_client_df

def _pairwise_distortion_stats(d_orig: np.ndarray, d_bin: np.ndarray) -> Dict[str, Any]:
    """Summarise how ``d_bin`` deviates from ``d_orig`` over all unordered pairs.

    Both arguments are square matrices over the same points in the same
    order; only the entries with ``i < j`` are read. The signed difference
    is ``d_bin - d_orig`` and the standard deviation is the population one.
    """
    upper = np.triu_indices(d_orig.shape[0], k=1)
    orig = d_orig[upper]
    embedded = d_bin[upper]
    num_pairs = int(orig.size)

    if num_pairs == 0:
        # Define every statistic (as zero) so the caller never sees a missing key.
        return {
            "num_pairs": 0,
            "max_abs_diff": 0.0,
            "mean_orig": 0.0,
            "mean_bin": 0.0,
            "mean_signed_diff": 0.0,
            "mean_abs_diff": 0.0,
            "std_signed_diff": 0.0,
        }

    diff = embedded - orig
    abs_diff = np.abs(diff)
    return {
        "num_pairs": num_pairs,
        "max_abs_diff": float(abs_diff.max()),
        "mean_orig": float(orig.mean()),
        "mean_bin": float(embedded.mean()),
        "mean_signed_diff": float(diff.mean()),
        "mean_abs_diff": float(abs_diff.mean()),
        "std_signed_diff": float(diff.std()),
    }


def compare_original_vs_binary_tree_distances(
    df: pd.DataFrame,
    num_clusters: int,
    clustering_method: str = "k-median",
    max_swaps: int = 10,
    feature_cols: Optional[Sequence[str]] = None,
    group_cols: Optional[Sequence[str]] = None,
    is_facility_col: str = "is_facility",
    capacity_col: Optional[str] = "capacity",
    local_search_seed: int = 123456789,
    compute_stats: bool = False,
) -> Tuple:
    """
    Build a binary-tree embedding starting from a data frame by:

      1. Restricting candidate centers to existing facilities.
      2. Running k-median or k-means over the facility-to-client distance
         matrix to select `num_clusters` centers.
      3. Building an MST+assignment tree over all points using these
         centers as the tree facilities.
      4. Embedding that tree into a binary tree.
      5. Optionally comparing the original feature-space distances with the
         binary-tree distances.

    Parameters
    ----------
    df
        One row per point. The index is reset, so rows are addressed by
        position 0..n-1 from here on.
    num_clusters
        Number of centers requested from the clustering routine.
    clustering_method
        Either "k-median" or "k-means".
    max_swaps, local_search_seed
        Forwarded to the clustering routine as `max_swaps` and `seed`.
    feature_cols
        Feature columns; defaults to every string column label starting
        with "f". Any sequence (list, tuple, ...) is accepted.
    group_cols
        Group columns; defaults to every string column label starting with
        "group". Forwarded to the tree builder.
    is_facility_col
        Column that is converted to a boolean facility mask.
    capacity_col
        Forwarded to the tree builder.
    compute_stats
        When true, the distortion statistics are computed and returned as a
        fourth tuple element.

    Returns
    -------
    binary_tree_root : BinTreeNode
    binary_tree_nodes : Dict[int, BinTreeNode]
    tree_to_bintree_idx : Dict[int, int]
        Maps original node_number (row index) to the binary-tree node_number.
    stats : Dict[str, Any]
        Only present when `compute_stats` is true. Contains:
          - num_points, num_nodes_in_binary_tree, num_pairs
          - max_abs_diff, mean_abs_diff
          - mean_signed_diff, std_signed_diff (binary-tree minus original)
          - mean_orig_dist / mean_original_distance
          - mean_bin_dist / mean_binary_tree_distance
          - difference_of_means
          - embedding_time (seconds spent on steps 1-4)

    Raises
    ------
    ValueError
        If `df` has fewer than two rows, contains no facilities, or
        `clustering_method` is not one of the supported names.
    """
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()

    # --- 0. Column defaults -------------------------------------------------
    if feature_cols is None:
        feature_cols = [
            c for c in df.columns if isinstance(c, str) and c.startswith("f")
        ]
    else:
        # pandas treats a tuple as a single (MultiIndex) key, so normalise any
        # sequence to a list before using it for column selection.
        feature_cols = list(feature_cols)
    if group_cols is None:
        group_cols = [
            c for c in df.columns if isinstance(c, str) and c.startswith("group")
        ]

    # We assume the DataFrame index is 0..n-1 in order.
    df = df.reset_index(drop=True)

    n = len(df)
    if n <= 1:
        raise ValueError("Need at least two points to build an embedding.")

    # is_facility as boolean mask
    is_facility = df[is_facility_col].to_numpy(dtype=bool)

    # --- 1. Split facilities and clients -----------------------------------
    X_facilities = df.loc[is_facility, feature_cols].to_numpy(dtype=float)
    X_clients = df.loc[~is_facility, feature_cols].to_numpy(dtype=float)

    if X_facilities.shape[0] == 0:
        raise ValueError("No facilities available (is_facility_col has no 1s).")

    # Reject an unsupported method before paying for the distance matrix.
    if clustering_method not in ("k-median", "k-means"):
        raise ValueError(f"Unknown clustering_method '{clustering_method}'")

    if X_clients.shape[0] == 0:
        # If there are no clients, fall back to clustering facilities vs themselves.
        X_clients = X_facilities

    # --- 2. Clustering over facilities only --------------------------------
    # 2a. facility-to-client distance matrix
    D = _compute_distance_matrix_two_sets(X_facilities, X_clients)

    # 2b. k-median / k-means clustering (only the chosen centers are used here)
    cluster = (
        k_median_clustering if clustering_method == "k-median" else k_means_clustering
    )
    centers, _client_center_map, _client_center_dist = cluster(
        D, num_clusters, max_swaps=max_swaps, seed=local_search_seed
    )

    # 2c. `centers` indexes the facility subset; map it back to df row positions.
    facility_idx_in_df = np.where(is_facility)[0]
    centers_idx_in_df = facility_idx_in_df[centers]

    # --- 3. Build MST+assignment tree using the chosen centers -------------
    tree_root, tree_nodes = build_mst_plus_assignment_tree(
        df=df,
        feature_cols=feature_cols,
        is_facility=is_facility,
        centers_idx_in_df=centers_idx_in_df,
        group_cols=group_cols,
        capacity_col=capacity_col,
    )

    # --- 4. Embed the general tree into a binary tree ----------------------
    binary_tree_root, binary_tree_nodes, tree_to_bintree_idx = embed_tree_to_binary_tree(
        tree_root, tree_nodes
    )
    embedding_time = time.perf_counter() - start_time

    if not compute_stats:
        return binary_tree_root, binary_tree_nodes, tree_to_bintree_idx

    # --- 5. Compute distortion statistics ----------------------------------
    # 5a. Original distances on feature space, shape (n, n)
    X_all = df[feature_cols].to_numpy(dtype=float)
    D_orig = np.asarray(_compute_distance_matrix_two_sets(X_all, X_all), dtype=float)

    # 5b. Binary-tree pairwise distances over every binary-tree node
    bin_node_ids, bin_dist_matrix = binary_tree_pairwise_distances(binary_tree_root)
    # Map binary-tree node_number -> row index in the matrix
    bin_id_to_idx: Dict[int, int] = {nid: i for i, nid in enumerate(bin_node_ids)}

    # 5c. Restrict the binary-tree matrix to the representatives of the
    #     original rows, in row order, so it lines up with D_orig.
    representative_rows = np.fromiter(
        (bin_id_to_idx[tree_to_bintree_idx[i]] for i in range(n)),
        dtype=np.intp,
        count=n,
    )
    D_bin = np.asarray(bin_dist_matrix, dtype=float)[
        np.ix_(representative_rows, representative_rows)
    ]

    distortion = _pairwise_distortion_stats(D_orig, D_bin)
    mean_orig = distortion["mean_orig"]
    mean_bin = distortion["mean_bin"]

    stats: Dict[str, Any] = {
        "num_points": n,
        "num_nodes_in_binary_tree": len(binary_tree_nodes),
        "num_pairs": distortion["num_pairs"],
        "max_abs_diff": distortion["max_abs_diff"],
        "mean_orig_dist": mean_orig,
        "mean_bin_dist": mean_bin,
        "mean_signed_diff": distortion["mean_signed_diff"],
        "mean_abs_diff": distortion["mean_abs_diff"],
        "std_signed_diff": distortion["std_signed_diff"],
        "mean_original_distance": mean_orig,
        "mean_binary_tree_distance": mean_bin,
        "difference_of_means": mean_bin - mean_orig,
        "embedding_time": embedding_time,
    }

    return binary_tree_root, binary_tree_nodes, tree_to_bintree_idx, stats

def test_stub():
    """Smoke-run the embedding pipeline on a synthetic data set and print its stats."""
    df = generate_random_facility_client_df(
        n_points=100,
        n_features=3,
        n_groups=2,
        facility_probability=0.5,
        max_capacity=5,
        seed=42,
    )

    # compute_stats must be enabled explicitly: with the default (False) the
    # call returns only the three tree objects and no statistics dictionary.
    _root, _nodes, _tree_to_bintree_idx, stats = compare_original_vs_binary_tree_distances(
        df,
        num_clusters=5,
        clustering_method="k-median",
        max_swaps=10,
        feature_cols=["f1", "f2", "f3"],
        group_cols=["group1", "group2"],
        is_facility_col="is_facility",
        capacity_col="capacity",
        compute_stats=True,
    )
    print(stats)


# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_stub()
