import numpy as np
from tree_utils import get_leaves


def construct_helmert_for_clade_sizes(k, clade_sizes):
    '''
    Construct a clade-size-weighted Helmert matrix H of shape (k, k-1).

    H satisfies:
    (A) sum_r w_r H_{r,m1} H_{r,m2} = delta_{m1,m2}  with w_r = 1/n_r
    (B) sum_r H_{r,m} = 0

    Column m (0-indexed) starts from the standard Helmert direction that
    puts +1 on clades 0..m, -(m+1) on clade m+1 and 0 elsewhere. It is then
    orthonormalised against the earlier columns under the w-weighted inner
    product (modified Gram-Schmidt), so clades after m+1 stay at 0 in
    column m.

    Parameters
    ----------
    k : int
        Number of clades; must equal len(clade_sizes) and be >= 1.
    clade_sizes : sequence of numbers
        Leaf count n_r of each clade; every entry must be > 0.

    Returns
    -------
    H : ndarray of shape (k, k-1)
        Has no columns when k == 1.

    Raises
    ------
    ValueError
        If k < 1, k != len(clade_sizes), or any clade size is not > 0.
        A zero size would otherwise produce infinite weights and NaNs.
    '''
    n = np.array(clade_sizes, dtype=float)
    if n.ndim != 1 or n.shape[0] != k:
        raise ValueError(
            f"k={k} does not match the number of clade sizes ({n.size})"
        )
    if k < 1:
        raise ValueError("at least one clade is required")
    # `not all(n > 0)` also rejects NaN sizes.
    if not np.all(n > 0):
        raise ValueError(f"clade sizes must be positive, got {n.tolist()}")
    w = 1.0 / n

    # Helmert basis of S = {h : sum_r h_r = 0} (dimension k-1).
    # Column m: clades 0..m get +1, clade m+1 gets -(m+1), rest get 0.
    basis = np.zeros((k, k - 1))
    for m in range(k - 1):
        basis[:m + 1, m] = 1.0
        basis[m + 1, m] = -(m + 1)

    # Modified Gram-Schmidt under the w-inner product <a, b> = sum w*a*b.
    # Every step is a combination of sum-zero vectors, so (B) is preserved.
    H = np.zeros((k, k - 1))
    for m in range(k - 1):
        v = basis[:, m].copy()
        for j in range(m):
            proj = np.sum(w * v * H[:, j])
            v = v - proj * H[:, j]
        norm = np.sqrt(np.sum(w * v * v))
        H[:, m] = v / norm

    return H


def construct_V_unweighted(tree, root, D, return_node_info=False):
    '''
    Construct the unweighted PolyILR contrast matrix V in R^{D x (D-1)}.

    Internal nodes are visited in pre-order starting at ``root``. Children
    are used in their original order (as stored in tree). A node with k
    children contributes k-1 columns. Column m of a node has the leaf-level
    contrast

        v_i = H_{r,m} / n_r   for every leaf i in child clade C_r,

    where H = construct_helmert_for_clade_sizes(k, clade_sizes).

    Parameters
    ----------
    tree : mapping
        tree[node] yields the children of internal node ``node``.
        Node ids >= 0 are treated as leaves and are used directly as row
        indices of V. Negative ids are internal nodes.
    root : int
        Node at which the traversal starts.
    D : int
        Number of rows of V.
    return_node_info : bool
        If True, also return a dict mapping each internal node to its
        'H', 'children', 'clade_sizes' and 'order'.

    Returns
    -------
    V : ndarray of shape (D, n_columns)
        An empty (D, 0) array when no contrasts are produced (e.g. the root
        is a leaf).
    node_info : dict
        Only returned when return_node_info is True.

    Raises
    ------
    ValueError
        If an internal node is reached more than once (cyclic input).
    '''
    contrast_vectors = []
    node_info = {} if return_node_info else None
    visited = set()

    # An explicit stack instead of recursion keeps deep (e.g. caterpillar)
    # trees clear of Python's recursion limit. Children are pushed in
    # reverse so they are popped in stored order. This reproduces the
    # pre-order, and therefore the column order, of a recursive walk.
    stack = [root]
    while stack:
        node = stack.pop()
        if node >= 0:  # leaves contribute no contrasts
            continue
        if node in visited:
            raise ValueError(
                f"internal node {node} reached more than once; "
                "tree must be acyclic"
            )
        visited.add(node)

        children = tree[node]
        k = len(children)
        child_leaf_sets = [get_leaves(tree, c) for c in children]
        clade_sizes = [len(s) for s in child_leaf_sets]

        H = construct_helmert_for_clade_sizes(k, clade_sizes)

        if return_node_info:
            node_info[node] = {
                'H': H,
                'children': children,
                'clade_sizes': clade_sizes,
                'order': list(range(k))
            }

        # Build all k-1 leaf-level contrasts of this node at once.
        # Rows of leaves in clade C_r get H[r, :] / n_r; other rows stay 0.
        block = np.zeros((D, k - 1))
        for r, leaf_set in enumerate(child_leaf_sets):
            block[list(leaf_set), :] = H[r, :] / clade_sizes[r]
        contrast_vectors.extend(block.T)

        stack.extend(reversed(list(children)))

    if contrast_vectors:
        V = np.column_stack(contrast_vectors)
    else:
        # np.column_stack([]) raises, so build the empty matrix explicitly.
        V = np.zeros((D, 0))

    if return_node_info:
        return V, node_info
    return V


def construct_V_weighted(tree, root, D, edge_lengths, return_node_info=False):
    '''
    Construct the weighted PolyILR contrast matrix V in R^{D x (D-1)}.

    At each polytomy, children are sorted by descending branch length before
    the Helmert contrasts are built. The first column of a node therefore
    contrasts the two children on the longest branches. This is intended to
    align the primary coordinates with the deepest splits. Subtrees are also
    visited in that sorted order (pre-order from ``root``), which fixes the
    column order of V. Column m of a node has the leaf-level contrast

        v_i = H_{r,m} / n_r   for every leaf i in (sorted) child clade C_r.

    Parameters
    ----------
    tree : mapping
        tree[node] is the sequence of children of internal node ``node``.
        Node ids >= 0 are treated as leaves and are used as row indices of V.
    root : int
        Node at which the traversal starts.
    D : int
        Number of rows of V.
    edge_lengths : mapping
        edge_lengths[(parent, child)] is the branch length of that edge.
        Missing edges default to 1.0.
    return_node_info : bool
        If True, also return a dict mapping each internal node to its
        'H', 'children' (sorted), 'children_original', 'clade_sizes',
        'branch_lengths' (sorted) and 'sort_order'.

    Returns
    -------
    V : ndarray of shape (D, n_columns)
        An empty (D, 0) array when no contrasts are produced.
    node_info : dict
        Only returned when return_node_info is True.

    Raises
    ------
    ValueError
        If an internal node is reached more than once (cyclic input).
    '''
    contrast_vectors = []
    node_info = {} if return_node_info else None
    visited = set()

    # An explicit stack instead of recursion avoids Python's recursion limit
    # on deep trees. Pushing the sorted children in reverse pops them in
    # sorted order, giving the same pre-order as a recursive walk.
    stack = [root]
    while stack:
        node = stack.pop()
        if node >= 0:  # leaves contribute no contrasts
            continue
        if node in visited:
            raise ValueError(
                f"internal node {node} reached more than once; "
                "tree must be acyclic"
            )
        visited.add(node)

        children = tree[node]
        k = len(children)

        # Branch length of each child edge (missing edges count as 1.0).
        branch_lens = [edge_lengths.get((node, c), 1.0) for c in children]

        # Sort children by descending branch length.
        # NOTE(review): argsort's default kind is not guaranteed stable, and
        # the result is reversed. Children with equal branch lengths
        # therefore do not keep their stored order.
        # construct_V_with_mapping applies the same rule.
        sorted_indices = np.argsort(branch_lens)[::-1]
        children_sorted = [children[i] for i in sorted_indices]
        branch_lens_sorted = [branch_lens[i] for i in sorted_indices]

        child_leaf_sets = [get_leaves(tree, c) for c in children_sorted]
        clade_sizes = [len(s) for s in child_leaf_sets]

        H = construct_helmert_for_clade_sizes(k, clade_sizes)

        if return_node_info:
            node_info[node] = {
                'H': H,
                'children': children_sorted,
                'children_original': children,
                'clade_sizes': clade_sizes,
                'branch_lengths': branch_lens_sorted,
                'sort_order': sorted_indices.tolist()
            }

        # Build all k-1 leaf-level contrasts of this node at once.
        # Rows of leaves in clade C_r get H[r, :] / n_r; other rows stay 0.
        block = np.zeros((D, k - 1))
        for r, leaf_set in enumerate(child_leaf_sets):
            block[list(leaf_set), :] = H[r, :] / clade_sizes[r]
        contrast_vectors.extend(block.T)

        stack.extend(reversed(children_sorted))

    if contrast_vectors:
        V = np.column_stack(contrast_vectors)
    else:
        # np.column_stack([]) raises, so build the empty matrix explicitly.
        V = np.zeros((D, 0))

    if return_node_info:
        return V, node_info
    return V


def construct_V(tree, root, D, edge_lengths=None, return_node_info=False):
    '''
    Build the PolyILR contrast matrix V for ``tree``.

    The builder is chosen by ``edge_lengths``:
      * None     -> construct_V_unweighted (children in stored order)
      * provided -> construct_V_weighted (children at each polytomy sorted
                    by descending branch length)

    Returns whatever the chosen builder returns: V, or (V, node_info) when
    return_node_info is True.
    '''
    use_weights = edge_lengths is not None
    if not use_weights:
        return construct_V_unweighted(
            tree, root, D, return_node_info=return_node_info
        )
    return construct_V_weighted(
        tree, root, D, edge_lengths, return_node_info=return_node_info
    )


def construct_V_with_mapping(tree, root, D, edge_lengths):
    """
    Construct V and a description of where each of its columns comes from.

    At every internal node the children are sorted by descending branch
    length, using the same rule as construct_V_weighted. Subtrees are
    visited in that order, pre-order from ``root``. A node with k children
    contributes k-1 columns, with leaf-level contrast
    v_i = H_{r,m} / n_r for every leaf i in (sorted) child clade C_r.

    Parameters
    ----------
    tree : mapping
        tree[node] is the sequence of children of internal node ``node``.
        Node ids >= 0 are treated as leaves and are used as row indices of V.
    root : int
        Node at which the traversal starts.
    D : int
        Number of rows of V.
    edge_lengths : mapping
        edge_lengths[(parent, child)] is the branch length of that edge.
        Missing edges default to 1.0.

    Returns
    -------
    V : ndarray of shape (D, n_columns)
        An empty (D, 0) array when no contrasts are produced.
    column_mapping : list of dict
        column_mapping[j] describes column j of V. Keys:
        'node', 'contrast_idx' (1-indexed position within the node),
        'n_contrasts' (k-1), 'children' (sorted), 'clade_sizes' and
        'child_leaves' (leaf sets in sorted child order).

    Raises
    ------
    ValueError
        If an internal node is reached more than once (cyclic input).
    """
    contrast_vectors = []
    column_mapping = []
    visited = set()

    # An explicit stack instead of recursion avoids Python's recursion limit
    # on deep trees. Pushing the sorted children in reverse pops them in
    # sorted order, giving the same pre-order as a recursive walk.
    stack = [root]
    while stack:
        node = stack.pop()
        if node >= 0:  # leaves contribute no contrasts
            continue
        if node in visited:
            raise ValueError(
                f"internal node {node} reached more than once; "
                "tree must be acyclic"
            )
        visited.add(node)

        children = tree[node]
        k = len(children)

        # Sort by descending branch length (missing edges count as 1.0).
        # NOTE(review): argsort's default kind is not guaranteed stable and
        # the result is reversed, so ties do not keep their stored order.
        branch_lens = [edge_lengths.get((node, c), 1.0) for c in children]
        sorted_indices = np.argsort(branch_lens)[::-1]
        children_sorted = [children[i] for i in sorted_indices]

        child_leaf_sets = [get_leaves(tree, c) for c in children_sorted]
        clade_sizes = [len(s) for s in child_leaf_sets]

        H = construct_helmert_for_clade_sizes(k, clade_sizes)

        # Build all k-1 leaf-level contrasts of this node at once.
        # Rows of leaves in clade C_r get H[r, :] / n_r; other rows stay 0.
        block = np.zeros((D, k - 1))
        for r, leaf_set in enumerate(child_leaf_sets):
            block[list(leaf_set), :] = H[r, :] / clade_sizes[r]
        contrast_vectors.extend(block.T)

        # One mapping entry per column, in the same order as the columns.
        for m in range(k - 1):
            column_mapping.append({
                'node': node,
                'contrast_idx': m + 1,  # 1-indexed
                'n_contrasts': k - 1,
                'children': children_sorted,
                'clade_sizes': clade_sizes,
                'child_leaves': child_leaf_sets,
            })

        stack.extend(reversed(children_sorted))

    if contrast_vectors:
        V = np.column_stack(contrast_vectors)
    else:
        # np.column_stack([]) raises, so build the empty matrix explicitly.
        V = np.zeros((D, 0))

    return V, column_mapping


def aitchison_distance(x, y):
    '''
    Aitchison distance between two compositions.

    Each composition is mapped to centred log-ratio (clr) coordinates,
    i.e. its logs minus their mean. The result is the Euclidean norm of
    the difference between the two clr vectors.
    '''
    def _clr(composition):
        logs = np.log(composition)
        return logs - np.mean(logs)

    return np.linalg.norm(_clr(x) - _clr(y))


def ilr_transform(p, V):
    '''
    Transform composition(s) to ILR coordinates.

    Parameters
    ----------
    p : array_like, shape (..., D)
        A single composition (1-D) or a stack of compositions, with parts
        on the last axis. Lists are accepted. Entries are log-transformed,
        so they should be > 0.
    V : ndarray, shape (D, D-1)
        Contrast matrix (e.g. from construct_V).

    Returns
    -------
    ndarray, shape (..., D-1)
        V.T @ log(p) for every composition in ``p``.
    '''
    # Right-multiplying by V acts on the last axis. One expression therefore
    # covers 1-D input, 2-D batches and higher-rank stacks alike, and avoids
    # reading `.ndim` (which plain lists lack).
    return np.log(p) @ V


def ilr_inverse(y, V):
    '''
    Transform ILR coordinates back to the simplex.

    Parameters
    ----------
    y : array_like, shape (..., D-1)
        A single coordinate vector (1-D) or a stack of them, with
        coordinates on the last axis. Lists are accepted.
    V : ndarray, shape (D, D-1)
        Contrast matrix used for the forward transform.

    Returns
    -------
    ndarray, shape (..., D)
        softmax(V @ y) along the last axis. Each composition sums to 1.
    '''
    # y @ V.T applies V to the last axis for any number of leading dims.
    z = np.asarray(y) @ V.T
    # Subtract the per-composition max before exponentiating to avoid
    # overflow; the shift cancels in the normalisation.
    exp_z = np.exp(z - z.max(axis=-1, keepdims=True))
    return exp_z / exp_z.sum(axis=-1, keepdims=True)