"""
metric_tree_embedding.py

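Embed a general metric (given as a CSV file or a pandas DataFrame of points)
into a dominating tree metric, convert that tree to a rooted binary tree, and
produce a structure compatible with the log-k DP on binary trees.

This file *does not* depend on the details of the DP; it only uses the tree
node types from tree.py / binary_tree.py together with the clustering,
MST-embedding and tree-binarization helpers imported below.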
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple
from data import generate_random_facility_client_df

import numpy as np
import pandas as pd

# import local modules
from tree import TreeNode, print_tree, print_tree_node
from binary_tree import BinTreeNode, print_binary_tree, print_binary_tree_node
from vanilla_clustering import k_median_clustering, k_means_clustering
from mst_tree_embedding import build_mst_plus_assignment_tree
from tree_to_binary_tree_embedding import embed_tree_to_binary_tree, check_embedding_correctness

def _compute_distance_matrix(X_facilities: np.ndarray, X_clients: np.ndarray) -> np.ndarray:
    """
    Compute the Euclidean distance from every facility to every client.

    Parameters
    ----------
    X_facilities : np.ndarray of shape (n_f, d)
        Feature vectors of the facilities, one row per facility.
    X_clients : np.ndarray of shape (n_c, d)
        Feature vectors of the clients, one row per client.

    Returns
    -------
    D : np.ndarray of shape (n_f, n_c)
        ``D[i, j]`` is the Euclidean distance between facility ``i`` and
        client ``j``. The matrix is rectangular (facilities x clients), so in
        general it is neither square nor symmetric. Either input may have zero
        rows, in which case the result is empty along that axis.

    Raises
    ------
    ValueError
        If either input is not 2-D, or the inputs have a different number of
        feature columns. Without this check NumPy broadcasting would silently
        accept a single-column input against a multi-column one.
    """
    facilities = np.asarray(X_facilities, dtype=float)
    clients = np.asarray(X_clients, dtype=float)

    if facilities.ndim != 2 or clients.ndim != 2:
        raise ValueError(
            f"expected 2-D arrays, got shapes {facilities.shape} and {clients.shape}"
        )
    if facilities.shape[1] != clients.shape[1]:
        raise ValueError(
            "feature dimension mismatch: "
            f"{facilities.shape[1]} (facilities) vs {clients.shape[1]} (clients)"
        )

    # Broadcast to (n_f, n_c, d), then reduce over the feature axis.
    # Peak memory is O(n_f * n_c * d) for the intermediate difference tensor.
    diff = facilities[:, None, :] - clients[None, :, :]
    return np.sqrt(np.sum(diff * diff, axis=2))

def _build_binary_tree_from_frame(
    df: pd.DataFrame,
    facility_col: str,
    group_cols: Optional[Sequence[str]],
    capacity_col: Optional[str],
    *,
    facility_positive_value: float,
    num_clusters: int,
    clustering_method: str,
    max_swaps: int,
    sanity_check: bool,
    source_name: str,
):
    """
    Shared pipeline behind the CSV and DataFrame entry points.

    Validates the input columns, splits rows into facilities and clients, and
    computes facility-to-client Euclidean distances. It then picks
    ``num_clusters`` facility centers with the requested clustering routine and
    builds the MST-plus-assignment tree via ``build_mst_plus_assignment_tree``.
    Finally it binarizes that tree with ``embed_tree_to_binary_tree``.

    ``source_name`` is only used in error messages ("CSV" / "DataFrame").

    Returns
    -------
    (binary_tree_root, binary_tree_nodes, orig_to_leaf_map)
        Exactly the triple returned by ``embed_tree_to_binary_tree``.

    Raises
    ------
    ValueError
        Unknown ``clustering_method``, empty input, a missing
        facility/capacity/group column, no feature columns left, or no row
        marked as a facility.
    """
    # Validate the (cheap) string option before doing any work on the data.
    if clustering_method not in ("k-median", "k-means"):
        raise ValueError(f"Unknown clustering_method '{clustering_method}'")

    if len(df) == 0:
        raise ValueError(f"{source_name} is empty")

    if facility_col not in df.columns:
        raise ValueError(f"facility_col '{facility_col}' not found in {source_name} columns")
    if capacity_col is not None and capacity_col not in df.columns:
        raise ValueError(f"capacity_col '{capacity_col}' not found in {source_name} columns")

    # group_cols may be None (the downstream tree builder accepts None), so
    # normalize to a list for validation / column filtering only.
    group_list = [] if group_cols is None else list(group_cols)
    for col in group_list:
        if col not in df.columns:
            raise ValueError(f"group column '{col}' not found in {source_name} columns")

    # Everything that is not a role/group/capacity column is a feature.
    ignore_cols = {facility_col, *group_list}
    if capacity_col is not None:
        ignore_cols.add(capacity_col)
    feature_cols = [c for c in df.columns if c not in ignore_cols]
    if not feature_cols:
        raise ValueError("No feature columns left after removing role/group/capacity columns.")

    is_facility = df[facility_col].to_numpy() == facility_positive_value
    if not is_facility.any():
        # Without at least one facility the clustering step has nothing to
        # choose centers from; fail here with a readable message.
        raise ValueError(
            f"No facility rows found: no value in column '{facility_col}' "
            f"equals {facility_positive_value!r}"
        )

    # Rectangular (n_facilities, n_clients) distance matrix.
    features = df[feature_cols].to_numpy(dtype=float)
    D = _compute_distance_matrix(features[is_facility], features[~is_facility])

    cluster = k_median_clustering if clustering_method == "k-median" else k_means_clustering
    # Only the chosen centers are needed here; the client->center assignment
    # is recomputed inside build_mst_plus_assignment_tree from the df.
    centers, _, _ = cluster(D, num_clusters, max_swaps=max_swaps, seed=0)

    # `centers` indexes facility rows; translate to positional rows of df.
    facility_idx_in_df = np.where(is_facility)[0]
    centers_idx_in_df = facility_idx_in_df[centers]

    tree_root, tree_nodes = build_mst_plus_assignment_tree(
        df=df,
        feature_cols=feature_cols,
        is_facility=is_facility,
        centers_idx_in_df=centers_idx_in_df,
        group_cols=group_cols,      # e.g. ["grp_A", "grp_B", ...] or None
        capacity_col=capacity_col,  # the caller's column name, or None
    )

    binary_tree_root, binary_tree_nodes, orig_to_leaf_map = embed_tree_to_binary_tree(
        tree_root, tree_nodes
    )

    if sanity_check:
        check_embedding_correctness(tree_root, binary_tree_root, orig_to_leaf_map)

    return binary_tree_root, binary_tree_nodes, orig_to_leaf_map


def build_binary_tree_from_csv(
    csv_path: str,
    facility_col: str,
    group_cols: Optional[Sequence[str]],
    capacity_col: Optional[str] = None,
    *,
    facility_positive_value: float = 1.0,
    num_clusters: int = 2,
    distance_metric: str = "euclidean",
    clustering_method: str = "k-median",
    sanity_check: bool = False,
    max_swaps: int = 10,
):
    """
    Read a CSV of points and embed it into a rooted binary tree.

    Parameters
    ----------
    csv_path : str
        Path to the CSV file (first row is the header). Each row is a data
        point; the feature columns are expected to be normalized to [0, 1]
        (not checked here). The file is read with ``dtype=float``, so every
        column, including role/group/capacity columns, must be numeric.
    facility_col : str
        Column marking facilities: rows whose value equals
        ``facility_positive_value`` are facilities, all others are clients.
    group_cols : Sequence[str] or None
        Columns describing group membership of facilities (0/1 values; a
        facility may belong to several groups). May be empty or None.
    capacity_col : str, optional
        Column holding facility capacities, forwarded to the tree builder.
        If None, no capacity column is passed on.
    facility_positive_value : float, default=1.0
        Value in ``facility_col`` that marks a facility.
    num_clusters : int, default=2
        Number of facility centers chosen by the clustering step.
    distance_metric : {"euclidean"}, default="euclidean"
        Metric used for the embedding (only Euclidean is implemented).
    clustering_method : {"k-median", "k-means"}, default="k-median"
        Clustering routine used to choose the facility centers.
    sanity_check : bool, default=False
        If True, run ``check_embedding_correctness`` on the result.
    max_swaps : int, default=10
        Forwarded to the clustering routine as ``max_swaps``.

    Returns
    -------
    (binary_tree_root, binary_tree_nodes)
        Root and nodes produced by ``embed_tree_to_binary_tree``. Use
        ``build_binary_tree_from_df`` if the original-node-to-leaf map is
        also needed.

    Raises
    ------
    NotImplementedError
        If ``distance_metric`` is not "euclidean".
    ValueError
        On invalid options or malformed input columns.
    """
    if distance_metric != "euclidean":
        raise NotImplementedError("Only euclidean metric is implemented for now.")

    df = pd.read_csv(csv_path, dtype=float, header=0, skipinitialspace=True)

    binary_tree_root, binary_tree_nodes, _ = _build_binary_tree_from_frame(
        df,
        facility_col,
        group_cols,
        capacity_col,
        facility_positive_value=facility_positive_value,
        num_clusters=num_clusters,
        clustering_method=clustering_method,
        max_swaps=max_swaps,
        sanity_check=sanity_check,
        source_name="CSV",
    )
    return binary_tree_root, binary_tree_nodes


def build_binary_tree_from_df(
    df: pd.DataFrame,
    facility_col: str,
    group_cols: Optional[Sequence[str]],
    capacity_col: Optional[str] = None,
    facility_positive_value: float = 1.0,
    num_clusters: int = 2,
    distance_metric: str = "euclidean",
    clustering_method: str = "k-median",
    max_swaps: int = 20,
    sanity_check: bool = False,
):
    """
    Embed the points of an in-memory DataFrame into a rooted binary tree.

    Same pipeline as ``build_binary_tree_from_csv`` but takes the data
    directly. It also returns the original-node-to-leaf map.

    Parameters
    ----------
    df : pd.DataFrame
        One row per data point; feature columns are every column that is not
        ``facility_col``, a group column or ``capacity_col``.
    facility_col : str
        Column marking facilities: rows whose value equals
        ``facility_positive_value`` are facilities, all others are clients.
    group_cols : Sequence[str] or None
        Columns describing group membership of facilities (0/1 values; a
        facility may belong to several groups). May be empty or None.
    capacity_col : str, optional
        Column holding facility capacities, forwarded to the tree builder.
        If None, no capacity column is passed on.
    facility_positive_value : float, default=1.0
        Value in ``facility_col`` that marks a facility.
    num_clusters : int, default=2
        Number of facility centers chosen by the clustering step.
    distance_metric : {"euclidean"}, default="euclidean"
        Metric used for the embedding (only Euclidean is implemented).
    clustering_method : {"k-median", "k-means"}, default="k-median"
        Clustering routine used to choose the facility centers.
    max_swaps : int, default=20
        Forwarded to the clustering routine as ``max_swaps``. Note that this
        default differs from the CSV entry point's default of 10.
    sanity_check : bool, default=False
        If True, run ``check_embedding_correctness`` on the result.

    Returns
    -------
    (binary_tree_root, binary_tree_nodes, orig_to_leaf_map)
        Exactly the triple returned by ``embed_tree_to_binary_tree``.

    Raises
    ------
    NotImplementedError
        If ``distance_metric`` is not "euclidean".
    ValueError
        On invalid options or malformed input columns.
    """
    if distance_metric != "euclidean":
        raise NotImplementedError("Only euclidean metric is implemented for now.")

    return _build_binary_tree_from_frame(
        df,
        facility_col,
        group_cols,
        capacity_col,
        facility_positive_value=facility_positive_value,
        num_clusters=num_clusters,
        clustering_method=clustering_method,
        max_swaps=max_swaps,
        sanity_check=sanity_check,
        source_name="DataFrame",
    )


def main():
    """
    Example usage: build binary trees for two toy CSV files and for one
    randomly generated DataFrame, running the embedding sanity check each time.

    The CSV paths are relative to the current working directory.
    """
    common_kwargs = dict(
        facility_col="is_facility",
        group_cols=["group1", "group2"],
        capacity_col="capacity",
        facility_positive_value=1.0,
        num_clusters=3,
        distance_metric="euclidean",
        clustering_method="k-median",
        sanity_check=True,
    )

    for csv_path in ("../data/toy_data.csv", "../data/toy_data_2.csv"):
        binary_tree_root, binary_tree_nodes = build_binary_tree_from_csv(
            csv_path=csv_path, **common_kwargs
        )

    df = generate_random_facility_client_df(
        n_points=20,
        n_features=3,
        n_groups=2,
        facility_probability=0.3,
        max_capacity=20 // 3,
        seed=123456789,
    )
    # The DataFrame entry point returns three values (root, nodes, leaf map);
    # unpacking only two raised "too many values to unpack".
    binary_tree_root, binary_tree_nodes, orig_to_leaf_map = build_binary_tree_from_df(
        df=df, **common_kwargs
    )


if __name__ == "__main__":
    main()
