import os
import jax
import json
import model
import numpy as np
from flax import nnx
import jax.numpy as jnp
from snapshot import Snapshot
from itertools import combinations
from exp_utils import TreeGenerator, TreeTokenizer


# Select XLA's 'platform' allocator for the Python client. Per the JAX GPU memory
# allocation documentation this allocates on demand and releases memory that is no
# longer needed instead of preallocating a pool (smaller footprint, slower).
# NOTE(review): this assignment runs after `import jax`; presumably it still takes
# effect because it precedes the first JAX computation in this file -- confirm.
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'


def safe_index(lst, item):
    """ Locate `item` in `lst` without raising when it is absent.
    :param lst: Sequence exposing an `index` method (e.g. a list)
    :param item: Element to look for

    :return: Position of the first match, or None when `index` raises ValueError
    """
    try:
        position = lst.index(item)
    except ValueError:
        position = None
    return position


def find_children(tree_strings):
    """ List the children of every route node (last one excluded) for each tree string.

    A tree string is read as `<pairs>[ROOT]...[ROUTE]<route>`, where `<pairs>` is a
    comma separated list of two-character parent/child edges and `<route>` is a
    sequence of single-character nodes.

    :param tree_strings: Either a single tree string or list of tree strings

    :return:
        - if single string: list with one children string per route node (except last)
        - if batch: list of such lists, one per tree string
      Within a children string, children that appear on the route come first, then the
      remaining ones; each group is sorted alphabetically. Nodes without recorded
      children yield ''. A string that does not contain exactly one '[ROUTE]' yields [].
    """

    def children_along_route(tree_string):
        # Exactly one '[ROUTE]' marker is required, otherwise nothing can be extracted
        pieces = tree_string.split('[ROUTE]')
        if len(pieces) != 2:
            return []
        edge_text, route = pieces

        # Only the text in front of '[ROOT]' holds parent-child pairs
        edge_text = edge_text.split('[ROOT]')[0]

        # parent -> children, in order of appearance; fragments that are not exactly
        # two characters long (including the empty string) are ignored
        children_of = {}
        for edge in edge_text.split(','):
            if len(edge) == 2:
                children_of.setdefault(edge[0], []).append(edge[1])

        on_route = set(route)
        per_node = []
        for node in route[:-1]:  # The last route node is skipped
            kids = children_of.get(node)
            if kids is None:
                per_node.append('')  # No children found
                continue
            route_first = sorted(kid for kid in kids if kid in on_route)
            others = sorted(kid for kid in kids if kid not in on_route)
            per_node.append(''.join(route_first + others))
        return per_node

    # A bare string is treated as a batch of one and unwrapped again on return
    is_single = isinstance(tree_strings, str)
    batch = [tree_strings] if is_single else tree_strings
    collected = [children_along_route(tree_string) for tree_string in batch]
    return collected[0] if is_single else collected


def decompose(latent, vocab):
    """ Find latent mixture recipe

    For every unordered pair of vocabulary vectors (v_i, v_j) the latent is fitted as
    alpha * v_i + beta * v_j through the normal equations, negative coefficients are
    clipped to 0, and the Euclidean distance between the latent and the clipped
    mixture is recorded. Pairs are then ranked per batch item by that distance.

    NOTE(review): clipping after an unconstrained solve does not re-fit the remaining
    coefficient, so this is a heuristic rather than an exact non-negative least squares.

    :param latent: latent vector waiting to be decomposed, shape: (batch, latent_dims)
    :param vocab: vocabulary latents, shape: (vocab_size, latent_dims)

    :return:
        - sorted_orders: int32, shape (batch, num_pairs, 2); vocabulary indices of each
          pair, best fit first. Inside a pair the index with the strictly larger
          coefficient comes first; on a tie the higher vocabulary index comes first.
        - sorted_distances: float32, shape (batch, num_pairs); matching distances in
          ascending order.
      num_pairs is C(vocab_size, 2); with fewer than two vocabulary vectors both
      outputs have a zero-length pair axis.
    :raises numpy.linalg.LinAlgError: if a pair of vocabulary vectors is linearly
        dependent (singular 2x2 Gram matrix)
    """
    # Upcast once, before any product is formed. The inputs may be low precision
    # (np.linalg.solve needs float32/float64 anyway); building the Gram matrix in e.g.
    # bfloat16 rounds it and can make a well-conditioned pair look singular.
    latent = jnp.asarray(latent, dtype=jnp.float32)
    vocab = jnp.asarray(vocab, dtype=jnp.float32)
    batch_size, _ = latent.shape
    vocab_size = vocab.shape[0]

    pairs = list(combinations(range(vocab_size), 2))
    if not pairs:
        # Nothing to pair up: return well-formed empty results
        return (
            jnp.zeros((batch_size, 0, 2), dtype=jnp.int32),
            jnp.zeros((batch_size, 0), dtype=jnp.float32),
        )

    # All dot products needed by the normal equations, computed once instead of once
    # per pair: gram[i, j] = v_i . v_j and corr[i, b] = v_i . latent_b
    gram = np.asarray(vocab @ vocab.T, dtype=np.float32)    # (vocab_size, vocab_size)
    corr = np.asarray(vocab @ latent.T, dtype=np.float32)   # (vocab_size, batch)

    # One combination at a time keeps the residual computation at (batch, latent_dims)
    # memory; columns are collected and stacked once at the end.
    distance_columns = []
    order_columns = []
    for i, j in pairs:
        picked = [i, j]
        ata = gram[np.ix_(picked, picked)]  # (2, 2)
        atb = corr[picked]                  # (2, batch)

        # Solve normal equations for all batch items
        coeffs = np.linalg.solve(ata, atb).T  # (batch, 2)

        # Clip negative coefficients to 0 for non-negative constraint
        coeffs = np.maximum(coeffs, 0)

        # Reconstruct the mixture and measure how far it is from the latent
        basis = jnp.stack([vocab[i], vocab[j]])        # (2, latent_dims)
        projections = jnp.asarray(coeffs) @ basis      # (batch, latent_dims)
        distance_columns.append(jnp.linalg.norm(latent - projections, axis=1))

        # Major contributor first: [i, j] when alpha > beta, otherwise [j, i]
        alpha_greater = coeffs[:, 0] > coeffs[:, 1]  # (batch,)
        order_columns.append(np.where(
            alpha_greater[:, None],
            np.array([i, j], dtype=np.int32),
            np.array([j, i], dtype=np.int32),
        ))  # (batch, 2)

    distances = jnp.stack(distance_columns, axis=1)                          # (batch, num_pairs)
    orders = jnp.asarray(np.stack(order_columns, axis=1), dtype=jnp.int32)   # (batch, num_pairs, 2)

    # Sort by distance for each batch item and apply the same permutation to the orders
    sorted_indices = jnp.argsort(distances, axis=1)
    sorted_distances = jnp.take_along_axis(distances, sorted_indices, axis=1)
    sorted_orders = jnp.take_along_axis(orders, sorted_indices[:, :, None], axis=1)

    return sorted_orders, sorted_distances


def inference(instance: model.Reasoner, token: np.ndarray, depth: int):
    """ Roll the reasoner forward in latent space and return the generated latents.

    The token batch is cut to its first `start` positions (the original comment calls
    this removing the answers), embedded, and then extended one position at a time
    until it is `end` positions long. Each new position is filled with the reasoner
    output at the previous position.

    :param instance: Reasoner providing `embed` and `reason`
    :param token: Token ids, indexed as (batch, sequence)
    :param depth: Tree depth; selects the (start, end) positions

    :return: Latents for positions start..end-1, i.e. `end - start` generated
        positions per batch item (2, 3 and 4 for depth 2, 3 and 4; none otherwise)
    """
    # (start, end) per depth; any other depth falls back to (64, 64)
    spans = {2: (10, 12), 3: (22, 25), 4: (46, 50)}
    start, end = spans.get(depth, (64, 64))

    @nnx.jit(static_argnames=['start_index', 'end_index'])
    def core(ins: model.Reasoner, lat: jnp.ndarray, start_index: int, end_index: int):
        rows = lat.shape[0]
        # Append zero-filled slots for the positions that are about to be generated
        slots = jnp.zeros(shape=(rows, end_index - start_index, *lat.shape[2:]))
        buffer = jnp.concatenate([lat, slots], axis=1)

        def step(pos: int, buf: jnp.ndarray):
            # Only the first pos + 1 positions are marked valid for this step
            valid = jnp.full((rows,), pos + 1)
            out = ins.reason(buf, mask=None, kv_len=valid, q_len=valid)
            # The output at `pos` becomes the input latent at `pos + 1`
            return buf.at[:, pos + 1].set(out[:, pos])

        filled = jax.lax.fori_loop(start_index - 1, end_index - 1, step, buffer)
        return filled[:, start_index:]

    prompt_latent = instance.embed(token[:, :start])  # Remove answers
    return core(instance, prompt_latent, start, end)   # Shape: (batch, depth, feature)


def main(depth: int, batch_size: int, snapshot: str):
    """ Check how well the generated latents decompose into the route children embeddings.

    Steps: read the model hyper-parameters from './configs/reasoner_tree.json', build a
    `model.Reasoner`, load the snapshot weights, generate a batch of trees, run latent
    inference on them, decompose every latent except the first one into ranked pairs
    of vocabulary embeddings, and print where the true children pair of each route
    node lands in that ranking (0 is the best fit), followed by summary statistics.

    :param depth: Tree depth; also selects the tokenizer context length
    :param batch_size: Number of trees to generate and evaluate
    :param snapshot: Snapshot path; its directory is handed to `Snapshot` and its
        base name to `Snapshot.load`
    """
    with open('./configs/reasoner_tree.json', 'r') as file:
        config = json.load(file)

    model_config = config['model']
    feature = int(model_config['Feature'])
    attn_feature = int(model_config['ATTN Feature'])
    ffn_feature = int(model_config['FFN Feature'])
    num_head = int(model_config['Head Count'])
    decoder_count = int(model_config['Decoder Count'])
    init_scalar = float(model_config['Init Scalar'])
    max_len = int(model_config['Max Length'])
    rope_base = float(model_config['RoPE Base'])

    context_len = 14 if depth == 2 else 26 if depth == 3 else 52 if depth == 4 else 64

    # Build model instance
    key = jax.random.key(0)
    instance = model.Reasoner(
        feature=feature,
        attn_feature=attn_feature,
        ffn_feature=ffn_feature,
        num_head=num_head,
        decoder_count=decoder_count,
        is_causal=True,
        init_scalar=init_scalar,
        vocab_size=32,
        key=key,
        dtype=jnp.bfloat16
    )
    instance.eval(
        rope_base=rope_base,
        max_len=max_len
    )

    snap = Snapshot(os.path.dirname(snapshot))
    instance = snap.load(os.path.basename(snapshot), instance)

    # Tree generator and tokenizer setup
    tree_generator = TreeGenerator(
        depth=depth,
        node_width=2,  # Binary trees
        batch_size=batch_size
    )
    tree_tokenizer = TreeTokenizer(context_len=context_len)
    tree_strings = tree_generator.generate_trees()
    # NOTE(review): `vocab_data` is encoded from the same `tree_strings` as `data`
    # below, so `vocabs` is the embedding of the first encoded tree rather than an
    # obviously A~Z sequence. Confirm against TreeTokenizer.encode that this really
    # yields one latent per letter as the comment on the next line states.
    vocab_data, _, _, _ = tree_tokenizer.encode(tree_strings)
    vocabs = jnp.asarray(instance.embed(vocab_data))[0]  # Latent for A~Z, shape: (26, feature)

    children = find_children(tree_strings)
    data, _, _, _ = tree_tokenizer.encode(tree_strings)

    latents = inference(instance, data, depth)
    # Remove the first latent since it's tree root, no motivation for latent mixing
    latents = latents[:, 1:].reshape(-1, latents.shape[-1])

    sorted_orders, sorted_distances = decompose(latents, vocabs)
    composition = np.asarray(sorted_orders)
    composition = composition.reshape(batch_size, depth - 1, -1, 2).tolist()

    indices = []
    correct_major_contributor = 0
    for i, seq in enumerate(composition):
        print(f'Tree: {i}')
        print(f'    Targets: {children[i]}')
        for j, lat_comb_list in enumerate(seq):
            target_chars = children[i][j]
            if target_chars:
                # Convert target characters to vocabulary indices
                target_indices = [ord(c) - ord('A') for c in target_chars]

                # Try both possible orderings since find_children and decompose may order differently
                order1 = target_indices  # Original order
                order2 = target_indices[::-1]  # Reversed order

                # Find position in the sorted list for both orderings
                idx1 = safe_index(lat_comb_list, order1)
                idx2 = safe_index(lat_comb_list, order2)

                # Choose the smaller index if both exist, otherwise take the one that exists
                if idx1 is not None and idx2 is not None:
                    idx = min(idx1, idx2)
                elif idx1 is not None:
                    # Found in find_children order (route child first), i.e. the
                    # route child carries the larger coefficient
                    idx = idx1
                    correct_major_contributor += 1
                elif idx2 is not None:
                    idx = idx2
                else:
                    idx = None

                if idx is not None:
                    indices.append(idx)
                print(f'    Latent Position {j + 1}: Target "{target_chars}" -> {target_indices}; index: {idx}')
            else:
                print(f'    Latent Position {j + 1}: No target characters')

    if indices:
        # Number of unordered vocabulary pairs, C(n, 2)
        num_pairs = vocabs.shape[0] * (vocabs.shape[0] - 1) // 2
        average_position = sum(indices) / len(indices)
        # '.2f' = two decimals; the former '02f' only set a minimum width of 2
        print('=' * 64)
        print(f'Average position: {average_position:.2f}/{num_pairs}')
        print(f'Percentage: {average_position / num_pairs * 100:.2f}%')
        print(f'Correct major contributor: {(correct_major_contributor / (batch_size * (depth - 1)) * 100):.2f}%')
        print('=' * 64)
    else:
        print('=' * 64 + '\nNo valid indices found\n' + '=' * 64)


if __name__ == '__main__':
    # Script entry point: evaluate depth-4 trees on a fixed snapshot
    main(
        depth=4,
        batch_size=10240,
        snapshot='snapshot/tree_search/reasoner/exp_sst_tree-search-d4/exp_tree-search_6000',  # Do not include ".safetensors" suffix
    )
