import os
import jax
import json
import model
import numpy as np
from flax import nnx
import jax.numpy as jnp
from snapshot import Snapshot
from itertools import combinations
from exp_utils import TreeGenerator, TreeTokenizer
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Colour palette: a dark and a light shade for each of blue, red and green.
BLUE, LIGHT_BLUE = '#174EA6', '#4285F4'
RED, LIGHT_RED = '#A50E0E', '#EA4335'
GREEN, LIGHT_GREEN = '#0D652D', '#34A853'

# Figure-wide matplotlib defaults applied to every plot this script creates.
plt.rcParams.update({'font.family': 'Serif', 'font.size': 20})


def safe_index(lst, item):
    """Return the position of *item* in *lst*, or None when it is absent."""
    try:
        position = lst.index(item)
    except ValueError:
        position = None
    return position


def find_children(tree_strings):
    """Find the children of every node on the [ROUTE] of each tree string.

    Each tree string is split on ``[ROUTE]``; the text after it is the route
    (one character per node). The text before the first ``[ROOT]`` marker is
    parsed as comma-separated two-character ``parent+child`` edges; items of
    any other length are ignored.

    Args:
        tree_strings: a single tree string, or an iterable of tree strings.

    Returns:
        For each tree string, a list with one entry per route node except the
        last. Each entry concatenates that node's children: children that are
        on the route first, then the remaining children, each group sorted.
        A node with no recorded children yields ''. A tree string without
        exactly one ``[ROUTE]`` marker yields []. When a single string is
        passed, its list is returned directly instead of a list of lists.
    """
    single_input = isinstance(tree_strings, str)
    if single_input:
        tree_strings = [tree_strings]

    results = []
    for tree_string in tree_strings:
        parts = tree_string.split('[ROUTE]')
        if len(parts) != 2:
            # No route, or more than one route marker: nothing to report.
            results.append([])
            continue
        route = parts[1]
        edge_text = parts[0].split('[ROOT]')[0]

        # ''.split(',') == [''] is skipped by the length check, so no guard is needed.
        parent_child_map = {}
        for pair in edge_text.split(','):
            if len(pair) == 2:
                parent_child_map.setdefault(pair[0], []).append(pair[1])

        # Route membership is the same for every node; build the set once.
        route_set = set(route)
        route_children = []
        for node in route[:-1]:
            children_list = parent_child_map.get(node, [])
            on_route = sorted(child for child in children_list if child in route_set)
            off_route = sorted(child for child in children_list if child not in route_set)
            route_children.append(''.join(on_route + off_route))
        results.append(route_children)

    return results[0] if single_input else results


def decompose_with_coefficients(latent, vocab):
    """Score every vocabulary pair as a two-term reconstruction of each latent.

    For each unordered pair (i, j), i < j, of vocabulary rows, each latent x
    is fit as ``a * vocab[i] + b * vocab[j]`` by solving the least-squares
    normal equations. Negative coefficients are then clamped to 0. This is an
    approximation, not an exact non-negative least-squares fit. The Euclidean
    residual ``||x - projection||`` is recorded. All per-pair results are
    finally sorted per latent by ascending residual.

    Arithmetic is done in float32 NumPy, so low-precision inputs (e.g.
    bfloat16) do not degrade the normal equations. Result buffers are
    filled in place rather than copied once per pair.

    Args:
        latent: (batch, dim) array of latent vectors.
        vocab: (vocab_size, dim) array of vocabulary vectors.

    Returns:
        ``(orders, distances, coefficients, projections, vocab_pairs)`` as JAX
        arrays. The first four are sorted along axis 1, best pair first:

        * orders (batch, n_pairs, 2) int32: the pair's vocab indices with the
          larger coefficient first; on a tie, j comes first.
        * distances (batch, n_pairs): residual norms.
        * coefficients (batch, n_pairs, 2): clamped (a, b) for
          (vocab[i], vocab[j]).
        * projections (batch, n_pairs, dim): ``a * vocab[i] + b * vocab[j]``.
        * vocab_pairs (n_pairs, 2): the unsorted (i, j) pairs.

        With fewer than two vocabulary rows, n_pairs is 0.

    Raises:
        numpy.linalg.LinAlgError: if a pair of vocabulary vectors is linearly
            dependent (singular normal equations).
    """
    latent_np = np.asarray(latent, dtype=np.float32)
    vocab_np = np.asarray(vocab, dtype=np.float32)
    batch_size, latent_dims = latent_np.shape

    # reshape keeps a well-formed (0, 2) array when vocab has < 2 rows.
    vocab_pairs = np.array(list(combinations(range(vocab_np.shape[0]), 2)),
                           dtype=np.int32).reshape(-1, 2)
    num_combinations = vocab_pairs.shape[0]

    # Written in place: jnp `.at[].set` outside jit would copy each whole
    # buffer once per pair.
    distances = np.zeros((batch_size, num_combinations), dtype=np.float32)
    orders = np.zeros((batch_size, num_combinations, 2), dtype=np.int32)
    coefficients = np.zeros((batch_size, num_combinations, 2), dtype=np.float32)
    projections = np.zeros((batch_size, num_combinations, latent_dims), dtype=np.float32)

    latent_t = latent_np.T  # (dim, batch), shared by every pair
    for comb_idx, (i, j) in enumerate(vocab_pairs):
        A = np.stack([vocab_np[i], vocab_np[j]], axis=-1)  # (dim, 2)
        # Normal equations (A^T A) c = A^T x, solved for all latents at once.
        coeffs = np.linalg.solve(A.T @ A, A.T @ latent_t).T  # (batch, 2)
        coeffs = np.maximum(coeffs, 0)
        proj = coeffs @ A.T  # (batch, dim)

        coefficients[:, comb_idx] = coeffs
        projections[:, comb_idx] = proj
        distances[:, comb_idx] = np.linalg.norm(latent_np - proj, axis=1)
        # Dominant vocab index first; ties fall to (j, i).
        orders[:, comb_idx] = np.where((coeffs[:, 0] > coeffs[:, 1])[:, None],
                                       (i, j), (j, i))

    # Stable sort reproduces jnp.argsort's default tie-breaking.
    sorted_indices = np.argsort(distances, axis=1, kind='stable')
    gather = sorted_indices[:, :, None]
    sorted_distances = np.take_along_axis(distances, sorted_indices, axis=1)
    sorted_orders = np.take_along_axis(orders, gather, axis=1)
    sorted_coefficients = np.take_along_axis(coefficients, gather, axis=1)
    sorted_projections = np.take_along_axis(projections, gather, axis=1)

    return (jnp.asarray(sorted_orders), jnp.asarray(sorted_distances),
            jnp.asarray(sorted_coefficients), jnp.asarray(sorted_projections),
            jnp.asarray(vocab_pairs))


def inference(instance: model.Reasoner, token: np.ndarray, depth: int):
    """Roll out latent positions for a tree of the given depth.

    The first ``start`` token columns are embedded. The sequence is padded
    with (float32) zeros up to ``end`` positions. ``instance.reason`` is then
    applied once per position: its output at position ``p`` is written into
    slot ``p + 1`` for ``p`` in ``[start - 1, end - 1)``. Depths other than
    2, 3 and 4 use start == end == 64, which yields no new positions.

    Returns:
        Positions ``[start, end)`` of the rolled-out sequence.
    """
    spans = {2: (10, 12), 3: (22, 25), 4: (46, 50)}
    start, end = spans.get(depth, (64, 64))
    prompt_latent = instance.embed(token[:, :start])

    @nnx.jit(static_argnames=['start_index', 'end_index'])
    def core(reasoner, seq, start_index, end_index):
        batch = seq.shape[0]
        padding = jnp.zeros(shape=(batch, end_index - start_index, *seq.shape[2:]))
        buffer = jnp.concatenate([seq, padding], axis=1)

        def step(pos, buf):
            # kv_len / q_len are both pos + 1 for every batch element.
            lengths = jnp.full((batch,), pos + 1)
            predicted = reasoner.reason(buf, mask=None, kv_len=lengths, q_len=lengths)
            return buf.at[:, pos + 1].set(predicted[:, pos])

        buffer = jax.lax.fori_loop(start_index - 1, end_index - 1, step, buffer)
        return buffer[:, start_index:]

    return core(instance, prompt_latent, start, end)


def plot_2d_latent_space_matplotlib(latents, projections, vocab,
                                    save_path,
                                    figsize=(10, 8)):
    """Create a publication-ready 2-D PCA view of latents and vocabulary vectors.

    Latents, vocabulary vectors and projections are stacked and reduced with a
    single two-component PCA, so all three groups share one coordinate frame.
    The plot shows:

    * up to 400 randomly sampled latents, drawn as circles;
    * vocabulary vectors, drawn as diamonds labelled 'A', 'B', ...;
    * for up to 15 of the drawn latents, a faint line to the projection
      nearest to them in the 2-D plane.

    The figure is saved, shown and then closed. The explained variance of the
    two components is printed.

    Args:
        latents: array whose last axis is the feature dimension; leading axes
            are flattened.
        projections: array with the same feature dimension; leading axes are
            flattened.
        vocab: 2-D array of vocabulary vectors, one row per token.
        save_path: output image path (saved at 400 dpi).
        figsize: figure size in inches.
    """
    flat_latents = np.asarray(latents).reshape(-1, latents.shape[-1])
    flat_vocab = np.asarray(vocab)
    flat_projections = np.asarray(projections).reshape(-1, projections.shape[-1])

    # One PCA fit over all three groups so their coordinates are comparable.
    pca = PCA(n_components=2)
    reduced_data = pca.fit_transform(np.vstack([flat_latents, flat_vocab, flat_projections]))

    n_latents = flat_latents.shape[0]
    n_vocab = flat_vocab.shape[0]
    latent_2d = reduced_data[:n_latents]
    vocab_2d = reduced_data[n_latents:n_latents + n_vocab]
    proj_2d = reduced_data[n_latents + n_vocab:]

    fig, ax = plt.subplots(figsize=figsize)

    # Predicted latents (middle layer): a random subset keeps the plot readable.
    n_sample_latents = min(400, len(latent_2d))
    sample_indices = np.random.choice(len(latent_2d), n_sample_latents, replace=False)
    sampled_latents = latent_2d[sample_indices]
    ax.scatter(sampled_latents[:, 0], sampled_latents[:, 1],
               c=BLUE, s=50, alpha=0.7, marker='o',
               label='Predicted Latents', edgecolors='white', linewidths=0.5)

    # Vocabulary latents with letter labels (top layer).
    ax.scatter(vocab_2d[:, 0], vocab_2d[:, 1],
               c=RED, s=150, marker='D', alpha=0.9,
               edgecolors='black', linewidths=1.5,
               label='Vocabulary Latents', zorder=5)
    for i, (x, y) in enumerate(vocab_2d):
        ax.annotate(chr(ord('A') + i), (x, y), xytext=(5, 5),
                    textcoords='offset points', fontsize=11, fontweight='bold',
                    ha='left', va='bottom', zorder=6,
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.8, edgecolor='none'))

    # Connect some *plotted* latents to their nearest projection in PCA space,
    # so no line starts at a point that was left out of the random sample.
    if len(proj_2d) > 0:
        for point in sampled_latents[:15]:
            nearest = proj_2d[np.argmin(np.sum((proj_2d - point) ** 2, axis=1))]
            ax.plot([point[0], nearest[0]], [point[1], nearest[1]],
                    'gray', alpha=0.3, linewidth=0.8, zorder=1)

    ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)', fontsize=18)
    ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)', fontsize=18)
    ax.legend(loc='upper right', fontsize=18)

    # tight_layout recomputes the subplot margins itself; no manual adjust needed.
    fig.tight_layout()
    # High DPI for publication.
    fig.savefig(save_path, dpi=400, bbox_inches='tight')
    plt.show()
    # pyplot keeps figures alive until closed; release this one.
    plt.close(fig)

    total_variance = sum(pca.explained_variance_ratio_[:2])
    print(f"Total variance explained by first 2 PCs: {total_variance:.1%}")
    print(f"PC contributions: PC1={pca.explained_variance_ratio_[0]:.1%}, "
          f"PC2={pca.explained_variance_ratio_[1]:.1%}")


def main_2d_plot(depth: int, batch_size: int, snapshot: str, output_dir: str = 'paper_figures'):
    """Generate the 2-D latent-space visualization for the paper.

    Loads a Reasoner from ``./configs/reasoner_tree.json`` plus the given
    checkpoint. It then generates ``batch_size`` random trees, rolls out their
    latents with ``inference`` and decomposes them against the embeddings of
    'A'..'O'. The result is plotted to ``<output_dir>/latent_space_2d.jpg``.

    Args:
        depth: tree depth; must be 2, 3 or 4 (the depths with defined
            tokenizer context lengths and ``inference`` rollout spans).
        batch_size: number of trees to generate.
        snapshot: checkpoint path; its directory is opened with ``Snapshot``
            and its basename is the entry loaded.
        output_dir: directory for the figure (created if missing).

    Raises:
        ValueError: if *depth* is not 2, 3 or 4.
    """
    context_lens = {2: 14, 3: 26, 4: 52}
    if depth not in context_lens:
        # Other depths fall back to start == end == 64 in `inference`, giving
        # an empty rollout that only fails later, after the model is loaded.
        raise ValueError(f'depth must be one of {sorted(context_lens)}, got {depth}')
    context_len = context_lens[depth]

    os.makedirs(output_dir, exist_ok=True)

    print("Loading model and generating data...")

    # Load model configuration
    with open('./configs/reasoner_tree.json', 'r') as file:
        model_config = json.load(file)['model']
    feature = int(model_config['Feature'])
    attn_feature = int(model_config['ATTN Feature'])
    ffn_feature = int(model_config['FFN Feature'])
    num_head = int(model_config['Head Count'])
    decoder_count = int(model_config['Decoder Count'])
    init_scalar = float(model_config['Init Scalar'])
    max_len = int(model_config['Max Length'])
    rope_base = float(model_config['RoPE Base'])

    # Build model
    key = jax.random.key(0)
    instance = model.Reasoner(
        feature=feature, attn_feature=attn_feature, ffn_feature=ffn_feature,
        num_head=num_head, decoder_count=decoder_count, is_causal=True,
        init_scalar=init_scalar, vocab_size=32, key=key, dtype=jnp.bfloat16
    )
    instance.eval(rope_base=rope_base, max_len=max_len)

    # Load snapshot
    snap = Snapshot(os.path.dirname(snapshot))
    instance = snap.load(os.path.basename(snapshot), instance)

    tree_generator = TreeGenerator(depth=depth, node_width=2, batch_size=batch_size)
    tree_tokenizer = TreeTokenizer(context_len=context_len)

    # Embeddings of the 15 node letters 'A'..'O' (first 15 token positions).
    char_list = 'ABCDEFGHIJKLMNO'
    vocab_data, _, _, _ = tree_tokenizer.encode([char_list])
    vocabs = jnp.asarray(instance.embed(vocab_data[:, :15]))[0]

    # Generate reasoning data
    tree_strings = tree_generator.generate_trees()
    data, _, _, _ = tree_tokenizer.encode(tree_strings)

    print("Running inference...")
    latents = inference(instance, data, depth)
    latents = latents[:, 1:].reshape(-1, latents.shape[-1])  # Remove root latent

    print("Decomposing latent vectors...")
    # Only the sorted projections are needed for the plot.
    _, _, _, sorted_projections, _ = decompose_with_coefficients(latents, vocabs)
    projs = np.asarray(sorted_projections).reshape(batch_size, depth - 1, -1, latents.shape[-1])

    print("Creating 2D visualization...")
    plot_2d_latent_space_matplotlib(
        latents.reshape(batch_size, depth - 1, -1),
        projs,
        vocabs,
        save_path=os.path.join(output_dir, 'latent_space_2d.jpg'),
        figsize=(10, 8)
    )


if __name__ == '__main__':
    # Smaller batch (2048 trees) for a cleaner visualization.
    run_args = {
        'depth': 4,
        'batch_size': 2048,
        'snapshot': 'snapshot/tree_search/reasoner/exp_sst_tree-search-d4/exp_tree-search_6000',
        'output_dir': './results',
    }
    main_2d_plot(**run_args)