import numpy as np
import os
import json
from Bio import Phylo
from io import StringIO

def get_leaves(tree, node):
    '''Collect the leaf ids found beneath ``node``, left to right.

    Ids >= 0 are leaves and are returned as-is. A negative id that is not
    a key of ``tree`` contributes nothing.
    '''
    def walk(current):
        # Non-negative ids are leaves, whether or not they appear in `tree`.
        if current >= 0:
            yield current
        elif current in tree:
            for below in tree[current]:
                yield from walk(below)

    return list(walk(node))


def count_balances(tree):
    '''Total number of balances: each internal node with k children adds k - 1.

    For a well-formed tree over D leaves this should come out as D - 1.
    '''
    total_children = sum(map(len, tree.values()))
    return total_children - len(tree)

def get_tree_statistics(tree):
    '''Summarize the branching structure of a tree.

    Args:
        tree: mapping of internal node id -> list of child ids.

    Returns:
        dict with
          * 'n_internal': number of internal nodes (keys of ``tree``)
          * 'n_binary':   internal nodes with exactly 2 children
          * 'n_polytomy': internal nodes with more than 2 children
          * 'max_k':      largest child count
          * 'mean_k':     mean child count
        For a tree without internal nodes every entry is zero.
    '''
    branching = [len(children) for children in tree.values()]
    if not branching:
        # max() raises ValueError on an empty sequence, so report an
        # all-zero summary for a tree with no internal nodes instead.
        return {
            'n_internal': 0,
            'n_binary': 0,
            'n_polytomy': 0,
            'max_k': 0,
            'mean_k': 0.0,
        }
    return {
        'n_internal': len(tree),
        'n_binary': sum(1 for k in branching if k == 2),
        'n_polytomy': sum(1 for k in branching if k > 2),
        'max_k': max(branching),
        'mean_k': np.mean(branching),
    }

def save_tree(tree, root, edge_lengths, path, metadata=None):
    '''Write a tree into the directory ``path`` (created if missing).

    Files written:
      * tree.json         - node id (as string) -> list of child ids
      * root.txt          - the root id
      * edge_lengths.json - "parent,child" -> length
      * metadata.json     - only when ``metadata`` is truthy
    '''
    os.makedirs(path, exist_ok=True)

    def target(filename):
        return os.path.join(path, filename)

    # JSON keys must be strings; children are cast to plain int because
    # numpy int64 values are not JSON serializable.
    serializable_tree = {}
    for parent, children in tree.items():
        serializable_tree[str(parent)] = [int(child) for child in children]
    with open(target('tree.json'), 'w') as handle:
        json.dump(serializable_tree, handle)

    with open(target('root.txt'), 'w') as handle:
        handle.write(str(root))

    # Tuple keys are flattened to "parent,child" strings.
    serializable_lengths = {}
    for (parent, child), length in edge_lengths.items():
        serializable_lengths[f"{parent},{child}"] = length
    with open(target('edge_lengths.json'), 'w') as handle:
        json.dump(serializable_lengths, handle)

    if not metadata:
        return
    with open(target('metadata.json'), 'w') as handle:
        json.dump(metadata, handle, indent=2)

def load_tree(path):
    '''Read a tree back from the directory ``path``.

    Expects tree.json, root.txt and edge_lengths.json; metadata.json is
    optional. Returns ``(tree, root, edge_lengths, metadata)`` where
    ``metadata`` is None when no metadata file exists.
    '''
    def read_json(filename):
        with open(os.path.join(path, filename), 'r') as handle:
            return json.load(handle)

    # JSON object keys are strings; restore the integer node ids.
    tree = {}
    for key, children in read_json('tree.json').items():
        tree[int(key)] = children

    with open(os.path.join(path, 'root.txt'), 'r') as handle:
        root = int(handle.read().strip())

    # Keys look like "parent,child"; rebuild the (parent, child) tuples.
    edge_lengths = {}
    for key, length in read_json('edge_lengths.json').items():
        ends = key.split(',')
        edge_lengths[(int(ends[0]), int(ends[1]))] = length

    meta_path = os.path.join(path, 'metadata.json')
    metadata = read_json('metadata.json') if os.path.exists(meta_path) else None

    return tree, root, edge_lengths, metadata

def force_binary_random(tree, root, edge_lengths, rng):
    '''Turn a tree into a strictly binary one, splitting polytomies at random.

    Ids >= 0 are leaves; negative ids are internal nodes looked up in
    ``tree``. Working bottom-up from ``root``:

      * a node with one child is dropped and replaced by that child;
      * a node with two children is kept as is;
      * a node with more than two children has them shuffled with ``rng``
        and chained pairwise under freshly created internal nodes.

    Edge lengths: edges created beneath a new internal node are always
    1.0. Edges leaving an original node are looked up in ``edge_lengths``
    under (node, resolved child) and default to 1.0 when absent, which is
    the case for edges to new nodes and to a child that replaced a
    collapsed single-child node.

    Args:
        tree: dict of internal node id -> list of child ids (not modified).
        root: id of the root node.
        edge_lengths: dict of (parent, child) -> length.
        rng: object with an in-place ``shuffle(list)`` method.

    Returns:
        (binary_tree, new_root, binary_edge_lengths). ``new_root`` differs
        from ``root`` when the root itself had a single child.

    Raises:
        KeyError: a negative id reached from ``root`` is not in ``tree``.
        ValueError: an internal node has no children.
    '''
    binary_tree = {}
    binary_edge_lengths = {}
    # Fresh ids are allocated below every existing id. default=0 lets a
    # leaf-only tree (empty dict) through instead of failing in min().
    node_id = min(tree, default=0) - 1

    def new_id():
        nonlocal node_id
        node_id -= 1
        return node_id

    def resolve(node):
        if node >= 0:
            return node

        # Children are resolved first, so rng is consumed in post-order.
        resolved = [resolve(c) for c in tree[node]]
        if not resolved:
            raise ValueError(f"internal node {node} has no children")

        # Single child: collapse (pass through)
        if len(resolved) == 1:
            return resolved[0]

        # Only genuine polytomies draw from rng; `resolved` is a fresh
        # list, so shuffling it in place leaves the input tree untouched.
        if len(resolved) > 2:
            rng.shuffle(resolved)

        # Left fold: pair the running subtree with the next child under a
        # new node, leaving the last child to sit directly under `node`.
        # For exactly two children the loop body never runs.
        *front, last = resolved
        left = front[0]
        for right in front[1:]:
            nid = new_id()
            binary_tree[nid] = [left, right]
            binary_edge_lengths[(nid, left)] = 1.0
            binary_edge_lengths[(nid, right)] = 1.0
            left = nid

        binary_tree[node] = [left, last]
        for child in (left, last):
            binary_edge_lengths[(node, child)] = edge_lengths.get((node, child), 1.0)
        return node

    new_root = resolve(root)
    return binary_tree, new_root, binary_edge_lengths

def phylo_to_tree(T, otu_index):
    """
    Convert a Bio.Phylo tree `T` into the internal H-ILR tree format:
      * leaves = 0..D-1 in the order of otu_index
      * internal nodes = -1, -2, -3, ... numbered in pre-order
      * returns: (tree, root, edge_lengths)

    `tree` maps each internal node id to the list of its child ids and
    `edge_lengths` maps (parent, child) to the child's branch length.
    A missing branch length (None) becomes 1.0; an explicit 0.0 is kept.

    Raises KeyError if a terminal clade's name is not in `otu_index`.
    """
    # map leaf names → indices in otu order
    leaf_id = {name: i for i, name in enumerate(otu_index)}

    tree = {}
    edge_lengths = {}
    next_internal = -1

    def node_id(clade):
        """Leaf index for a terminal clade, else the next unused negative id."""
        nonlocal next_internal
        if clade.is_terminal():
            return leaf_id[clade.name]
        assigned = next_internal
        next_internal -= 1
        return assigned

    def build(clade, parent):
        """Record children and edge lengths for the internal `clade`."""
        children = []
        for child in clade.clades:
            # The id is taken just before descending, so numbering is pre-order.
            cid = node_id(child)
            children.append(cid)
            # Only a missing length falls back to 1.0; a truthiness test
            # would also overwrite genuine zero-length branches.
            length = child.branch_length
            edge_lengths[(parent, cid)] = 1.0 if length is None else length
            if not child.is_terminal():
                build(child, cid)
        tree[parent] = children

    root = node_id(T.root)
    if not T.root.is_terminal():
        build(T.root, root)
    return tree, root, edge_lengths