import numpy as np
import pandas as pd
from typing import Sequence, Optional, Dict, Tuple, List

from tree import TreeNode   # <- use the TreeNode you defined earlier


def _compute_distance_matrix_two_sets(X1: np.ndarray, X2: np.ndarray) -> np.ndarray:
    """
    Return the Euclidean distance from every row of X1 to every row of X2.

    Parameters
    ----------
    X1 : array (n1, d)
        First point set, one point per row.
    X2 : array (n2, d)
        Second point set, one point per row (same dimension d as X1).

    Returns
    -------
    D : array (n1, n2)
        D[i, j] = ||X1[i] - X2[j]||_2

    Notes
    -----
    The computation broadcasts to an intermediate array of shape
    (n1, n2, d), so memory use grows with n1 * n2 * d.
    """
    # Insert singleton axes so the subtraction broadcasts over all (i, j) pairs.
    deltas = X1[:, np.newaxis, :] - X2[np.newaxis, :, :]
    # Sum squared coordinate differences over the feature axis, then take the root.
    squared_norms = np.square(deltas).sum(axis=2)
    return np.sqrt(squared_norms)


def _prim_mst(D: np.ndarray) -> List[Tuple[int, int, float]]:
    """
    Prim's algorithm for MST on a complete graph with distance matrix D.

    The tree is grown from vertex 0. Ties between equally close vertices are
    broken in favour of the smallest index (first minimum of ``np.argmin``).

    Parameters
    ----------
    D : np.ndarray of shape (k, k)
        Pairwise distances; D[u, v] is the weight of edge (u, v).

    Returns
    -------
    edges : list of (u, v, w)
        MST edges using local indices 0..k-1. There is one edge per vertex
        v = 1..k-1 (listed in increasing v), where u is the vertex through
        which v was attached to the tree and w = D[u, v]. The list is empty
        when k <= 1.

    Raises
    ------
    ValueError
        If some vertex cannot be reached through comparable distances
        (for example because D contains NaN or +inf entries), in which case
        no spanning tree can be built.
    """
    k = D.shape[0]
    if k == 0:
        # Nothing to connect; indexing key[0] below would fail on an empty array.
        return []

    in_tree = np.zeros(k, dtype=bool)
    parent = np.full(k, -1, dtype=int)
    # key[v] = cheapest known edge weight connecting v to the current tree.
    key = np.full(k, np.inf, dtype=float)
    key[0] = 0.0  # start from vertex 0

    for _ in range(k):
        # Mask out vertices already in the tree, then take the cheapest one.
        candidates = np.where(in_tree, np.inf, key)
        u = int(np.argmin(candidates))
        if candidates[u] == np.inf:
            # Every remaining vertex is unreachable; emitting parent == -1
            # as an edge endpoint would silently index D from the end.
            raise ValueError(
                "cannot build a spanning tree: distance matrix leaves some "
                "vertices unreachable (NaN or infinite distances?)"
            )
        in_tree[u] = True

        # Relax all edges (u, v) towards vertices still outside the tree.
        closer = ~in_tree & (D[u] < key)
        key[closer] = D[u, closer]
        parent[closer] = u

    # Plain Python ints/floats so the result matches the annotated type.
    return [(int(parent[v]), v, float(D[parent[v], v])) for v in range(1, k)]


def build_mst_plus_assignment_tree(
    df: pd.DataFrame,
    feature_cols: Sequence[str],
    is_facility: np.ndarray,
    centers_idx_in_df: np.ndarray,
    group_cols: Optional[Sequence[str]] = None,
    capacity_col: Optional[str] = None,
) -> Tuple[TreeNode, Dict[int, TreeNode]]:
    """
    Build a tree metric where:

      1. The chosen centers (rows centers_idx_in_df) form a minimum
         spanning tree w.r.t. Euclidean distance on feature_cols.
      2. Every other data point (client or non-center facility) is
         attached to its nearest center (in that same metric); ties go to
         the center listed first in centers_idx_in_df.

    The output is a rooted tree built from the imported TreeNode class,
    with node_number == positional row index in df.

    Parameters
    ----------
    df : pd.DataFrame
        Data containing all points.
    feature_cols : sequence of str
        Columns used as coordinates for distance computations.
    is_facility : np.ndarray of shape (n,)
        Boolean array (same positional order as df) indicating which rows
        are facilities. Any array-like of length n is accepted and coerced
        to a boolean array.
    centers_idx_in_df : np.ndarray of shape (k,)
        Positional row indices (in df) of the rows that have been chosen as
        centers. Must be unique and lie in [0, n).
    group_cols : sequence of str, optional
        Columns describing group membership; if given, each facility node
        receives them as a 0/1 tuple (values > 0.5 map to 1).
    capacity_col : str, optional
        Column for facility capacities (truncated to int). If None, or if
        the column is not present in df, facilities get capacity 1.
        Clients always get capacity 0.

    Returns
    -------
    root : TreeNode
        Root of the constructed tree (the node for centers_idx_in_df[0]).
    nodes : dict
        Mapping row_index -> TreeNode for all rows in df.

    Raises
    ------
    ValueError
        If centers_idx_in_df is empty, contains duplicates or out-of-range
        indices, or if is_facility does not have exactly one entry per row.

    Notes
    -----
    For every non-root node this function sets ``parent``, ``parent_id`` and
    ``parent_distance``, and appends to the parent's ``children`` and
    ``children_distances``; the root's parent fields are left as TreeNode
    created them. It also sets ``is_leaf``, ``is_client``, ``capacity``,
    ``facility_type`` and ``facility_type_vector`` on every node.
    Assumes ``TreeNode(node_number=...)`` starts with empty ``children`` and
    ``children_distances`` lists -- TODO confirm against the TreeNode class.
    """
    n = len(df)

    centers_idx_in_df = np.asarray(centers_idx_in_df, dtype=int)
    if np.unique(centers_idx_in_df).size != centers_idx_in_df.size:
        raise ValueError("centers_idx_in_df must contain unique row indices")
    # Negative indices would silently alias rows from the end of X_all and
    # then fail later with an opaque KeyError, so reject them up front.
    if centers_idx_in_df.size > 0 and (
        centers_idx_in_df.min() < 0 or centers_idx_in_df.max() >= n
    ):
        raise ValueError("centers_idx_in_df must contain row indices in [0, len(df))")

    # Coerce to a positional boolean mask: `~` on a 0/1 integer array would
    # produce -1/-2, and a pandas Series would be indexed by label.
    is_facility = np.asarray(is_facility, dtype=bool)
    if is_facility.shape != (n,):
        raise ValueError("is_facility must have exactly one entry per row of df")

    # --- 1. Prepare coordinates ---

    # list(...) so that a tuple of column names is not taken as a single key.
    X_all = df[list(feature_cols)].to_numpy(dtype=float)
    X_centers = X_all[centers_idx_in_df]      # shape (k, d)
    k = X_centers.shape[0]
    if k == 0:
        raise ValueError("No centers provided")

    # --- 2. Build MST over the centers only ---

    D_centers = _compute_distance_matrix_two_sets(X_centers, X_centers)  # (k, k)
    mst_edges = _prim_mst(D_centers)  # edges in local center indices 0..k-1

    # Undirected adjacency on all n rows (center-center + point-center edges).
    adj: Dict[int, List[Tuple[int, float]]] = {i: [] for i in range(n)}

    # MST edges among centers, converted from local indices to df row indices.
    for u_loc, v_loc, w in mst_edges:
        u = int(centers_idx_in_df[u_loc])
        v = int(centers_idx_in_df[v_loc])
        adj[u].append((v, w))
        adj[v].append((u, w))

    # --- 3. Attach every non-center point to its nearest center ---

    others_idx_in_df = np.setdiff1d(
        np.arange(n, dtype=int), centers_idx_in_df, assume_unique=True
    )

    if others_idx_in_df.size > 0:
        D_co = _compute_distance_matrix_two_sets(
            X_centers, X_all[others_idx_in_df]
        )  # (k, m)
        # Nearest center per column; argmin keeps the first minimum on ties.
        nearest_loc = np.argmin(D_co, axis=0)
        nearest_dist = D_co[nearest_loc, np.arange(others_idx_in_df.size)]

        for row, c_loc, dist in zip(
            others_idx_in_df.tolist(), nearest_loc.tolist(), nearest_dist.tolist()
        ):
            center_row = int(centers_idx_in_df[c_loc])
            adj[center_row].append((row, dist))
            adj[row].append((center_row, dist))

    # At this point, adj is the adjacency of a tree whose:
    #   - internal structure on centers is the MST,
    #   - all other points are leaves attached to their closest center.

    # --- 4. Root the tree at one center and orient edges ---

    root_row = int(centers_idx_in_df[0])   # choose first center as root
    children: Dict[int, List[Tuple[int, float]]] = {i: [] for i in range(n)}
    visited = {root_row}

    # Iterative DFS; a neighbour that was already visited is the parent side
    # of the edge and is skipped.
    stack = [root_row]
    while stack:
        u = stack.pop()
        for v, w in adj[u]:
            if v in visited:
                continue
            visited.add(v)
            children[u].append((v, w))
            stack.append(v)

    # --- 5. Build TreeNode objects and attach structure ---

    # node_number == df row index
    nodes: Dict[int, TreeNode] = {row: TreeNode(node_number=row) for row in range(n)}

    # Attach parent/children pointers and edge lengths.
    for u, ch_list in children.items():
        u_node = nodes[u]
        for v, w in ch_list:
            v_node = nodes[v]
            u_node.children.append(v_node)
            u_node.children_distances.append(float(w))
            v_node.parent = u_node
            v_node.parent_id = u_node.node_number
            v_node.parent_distance = float(w)

    # --- 6. Mark leaves and fill facility/client, capacity, group info ---

    # Capacities: taken from the column when available, otherwise 1 for
    # facilities and 0 for clients.
    if capacity_col is not None and capacity_col in df.columns:
        capacities = df[capacity_col].to_numpy(dtype=float)
    else:
        capacities = np.ones(n, dtype=float)
        capacities[~is_facility] = 0.0

    # Group memberships, binarised at 0.5.
    if group_cols:
        groups = (df[list(group_cols)].to_numpy(dtype=float) > 0.5).astype(int)
        t = groups.shape[1]
    else:
        groups = None
        t = 0

    for row, tn in nodes.items():
        tn.is_leaf = (len(tn.children) == 0)

        if is_facility[row]:
            tn.is_client = False
            tn.capacity = int(capacities[row])

            if groups is not None:
                tn.facility_type_vector = tuple(int(v) for v in groups[row])
            else:
                tn.facility_type_vector = tuple()

            # For debugging: store first nonzero group index + 1, or -1.
            if groups is not None and groups[row].any():
                tn.facility_type = int(np.argmax(groups[row]) + 1)
            else:
                tn.facility_type = -1
        else:
            tn.is_client = True
            tn.capacity = 0
            tn.facility_type = -1
            # All-zero vector of the same length as the facility vectors.
            tn.facility_type_vector = (0,) * t

    return nodes[root_row], nodes
