import os
import jax
import json
import model
import numpy as np
import jax.numpy as jnp
from flax import nnx
from tqdm import tqdm
from snapshot import Snapshot
from exp_utils import TreeGenerator, TreeTokenizer


# Select XLA's 'platform' allocator. Per the JAX GPU memory allocation docs this
# allocates exactly what is needed on demand and releases memory that is no longer
# used, instead of preallocating a pool (slower, but minimal memory footprint).
# NOTE(review): presumably this has to be set before the first JAX backend
# initialization to take effect — keep it ahead of any device work.
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'


def inference_reasoner(instance, token, start, end):
    """Roll the reasoner forward to fill latent positions ``start`` .. ``end - 1``.

    The embedded prompt is extended along axis 1 with ``end - start`` zero-filled
    slots. For each position ``pos`` in ``[start, end)`` the reasoner is run on the
    whole buffer with ``kv_len`` and ``q_len`` both set to ``pos`` for every batch
    row, and its output at position ``pos - 1`` is written into slot ``pos``.

    Args:
        instance: Object exposing ``embed(token)`` and
            ``reason(latent, mask=..., kv_len=..., q_len=...)``.
        token: Input passed to ``instance.embed``; the embedding's first axis is
            used as the batch size.
        start: First position to generate (inclusive).
        end: Last position to generate (exclusive).

    Returns:
        The generated slice ``[:, start:end]`` of the latent buffer, i.e. an array
        of shape (batch, end - start, feature...).
    """
    embedded = instance.embed(token)
    batch = embedded.shape[0]

    # Zero slots for the positions still to be generated. They are created with
    # jnp.zeros' default dtype, so the concatenation follows JAX type promotion.
    padding = jnp.zeros(shape=(batch, end - start, *embedded.shape[2:]))
    buffer = jnp.concatenate([embedded, padding], axis=1)

    def step(pos: int, state: jnp.ndarray):
        # The same length vector is used for keys/values and for queries.
        lengths = jnp.full((batch,), pos)
        predicted = instance.reason(state, mask=None, kv_len=lengths, q_len=lengths)
        # The output at the previous position becomes the latent at `pos`.
        return state.at[:, pos].set(predicted[:, pos - 1])

    rolled = jax.lax.fori_loop(start, end, step, buffer)
    return rolled[:, start:end]   # Shape: (batch, depth, feature)


def inference_talker(instance, latent):
    """Turn latents into predicted token ids.

    Calls ``instance(latent)`` to obtain logits, normalizes them with a softmax
    over the last axis and returns the index of the largest probability along
    that axis.

    Args:
        instance: Callable mapping ``latent`` to logits.
        latent: Input forwarded unchanged to ``instance``.

    Returns:
        Integer array of argmax indices; the last axis of the logits is reduced.
    """
    distribution = jax.nn.softmax(instance(latent), axis=-1)
    token_ids = jnp.argmax(distribution, axis=-1)
    return token_ids


@nnx.jit(static_argnames=['start', 'end'])
def inference(reasoner, talker, data, start, end):
    """Jitted end-to-end prediction: reasoner rollout followed by talker decoding.

    ``start`` and ``end`` are declared static for the jit, so each distinct pair
    of values is compiled separately.

    Args:
        reasoner: Model handed to ``inference_reasoner``.
        talker: Model handed to ``inference_talker``.
        data: Input tokens for the reasoner.
        start: First latent position to generate (inclusive).
        end: Last latent position to generate (exclusive).

    Returns:
        Whatever ``inference_talker`` returns for the generated latents
        (argmax token ids).
    """
    return inference_talker(talker, inference_reasoner(reasoner, data, start, end))


def build_reasoner():
    """Construct a ``model.Reasoner`` from ``./configs/reasoner_tree.json``.

    Hyper-parameters are read from the ``'model'`` section of the config. The
    model is created causal, with a vocabulary of 32, bfloat16 dtype and a fixed
    PRNG key (seed 0), then ``instance.eval`` is called with the configured RoPE
    base and maximum length.

    Returns:
        The constructed Reasoner instance.
    """
    with open('./configs/reasoner_tree.json', 'r') as file:
        hparams = json.load(file)['model']

    # (local name, config key, converter), converted in this order.
    schema = (
        ('feature', 'Feature', int),
        ('attn_feature', 'ATTN Feature', int),
        ('ffn_feature', 'FFN Feature', int),
        ('num_head', 'Head Count', int),
        ('decoder_count', 'Decoder Count', int),
        ('init_scalar', 'Init Scalar', float),
        ('max_len', 'Max Length', int),
        ('rope_base', 'RoPE Base', float),
    )
    values = {name: convert(hparams[key]) for name, key, convert in schema}

    instance = model.Reasoner(
        feature=values['feature'],
        attn_feature=values['attn_feature'],
        ffn_feature=values['ffn_feature'],
        num_head=values['num_head'],
        decoder_count=values['decoder_count'],
        is_causal=True,
        init_scalar=values['init_scalar'],
        vocab_size=32,
        key=jax.random.key(0),
        dtype=jnp.bfloat16
    )
    instance.eval(rope_base=values['rope_base'], max_len=values['max_len'])
    return instance


def build_talker():
    """Construct a ``model.MonoTalker`` from the talker and reasoner configs.

    ``latent_feature`` is taken from the ``'Feature'`` entry of
    ``./configs/reasoner_tree.json``; every other hyper-parameter comes from the
    ``'model'`` section of ``./configs/talker_tree.json``. The model is created
    non-causal, with a vocabulary of 32, bfloat16 dtype and a fixed PRNG key
    (seed 0), then ``instance.eval`` is called with the configured RoPE base and
    maximum length.

    Returns:
        The constructed MonoTalker instance.
    """
    # The talker consumes the reasoner's latents, so their width is read from
    # the reasoner config.
    with open('./configs/reasoner_tree.json', 'r') as file:
        latent_feature = int(json.load(file)['model']['Feature'])

    with open('./configs/talker_tree.json', 'r') as file:
        hparams = json.load(file)['model']

    # (local name, config key, converter), converted in this order.
    schema = (
        ('feature', 'Feature', int),
        ('attn_feature', 'ATTN Feature', int),
        ('ffn_feature', 'FFN Feature', int),
        ('num_head', 'Head Count', int),
        ('decoder_count', 'Decoder Count', int),
        ('init_scalar', 'Init Scalar', float),
        ('max_len', 'Max Length', int),
        ('rope_base', 'RoPE Base', float),
    )
    values = {name: convert(hparams[key]) for name, key, convert in schema}

    instance = model.MonoTalker(
        feature=values['feature'],
        latent_feature=latent_feature,
        attn_feature=values['attn_feature'],
        ffn_feature=values['ffn_feature'],
        num_head=values['num_head'],
        decoder_count=values['decoder_count'],
        init_scalar=values['init_scalar'],
        is_causal=False,
        vocab_size=32,
        key=jax.random.key(0),
        dtype=jnp.bfloat16
    )
    instance.eval(rope_base=values['rope_base'], max_len=values['max_len'])
    return instance


def _snapshot_sort_key(name: str):
    """Order snapshot names by a trailing ``_<number>`` when present, else by name."""
    tail = name.rsplit('_', 1)[-1]
    if tail.isdecimal():
        return (0, int(tail), name)
    return (1, 0, name)


def main(depth: int, batch_size: int, reasoner_snapshot: str, talker_snapshot_base: str):
    """Evaluate every talker snapshot in a directory and report route accuracy.

    The reasoner is loaded once from ``reasoner_snapshot``. Each ``.safetensors``
    file found in ``talker_snapshot_base`` is then loaded into the talker, a fresh
    batch of trees is generated, and the decoded prediction is compared with the
    text following the last ``'[ROUTE]'`` marker of each tree string.

    Args:
        depth: Tree depth. Depths 2, 3 and 4 select dedicated
            (start, end, context_len) settings; any other value uses (64, 64, 64).
        batch_size: Number of trees generated per snapshot; also the denominator
            of the reported accuracy.
        reasoner_snapshot: Path of the reasoner snapshot; its directory and base
            name are handed to ``Snapshot`` separately.
        talker_snapshot_base: Directory containing the talker ``.safetensors``
            snapshots.

    Returns:
        None. Results are printed to stdout.
    """
    # (start, end, context_len) for the supported depths.
    depth_settings = {2: (10, 12, 14), 3: (22, 25, 26), 4: (46, 50, 52)}
    start, end, context_len = depth_settings.get(depth, (64, 64, 64))

    # Build models
    reasoner = build_reasoner()
    talker = build_talker()

    # Load JEPA-Reasoner's weight
    snap = Snapshot(os.path.dirname(reasoner_snapshot))
    reasoner = snap.load(os.path.basename(reasoner_snapshot), reasoner)

    # Setup snapshot list. A new list is built rather than removing entries from
    # the listing while indexing into it, which skips entries and can run past
    # the end. Only the '.safetensors' suffix is stripped so names that contain
    # dots stay intact. os.listdir returns entries in arbitrary order, so the
    # names are sorted to make the report order deterministic.
    # NOTE(review): assumes snapshot names end in '_<step>' (as the reasoner
    # snapshot name does); other names are sorted alphabetically after those.
    snap = Snapshot(talker_snapshot_base)
    suffix = '.safetensors'
    snap_names = sorted(
        (
            entry[:-len(suffix)]
            for entry in os.listdir(talker_snapshot_base)
            if entry.endswith(suffix) and len(entry) > len(suffix)
        ),
        key=_snapshot_sort_key,
    )
    # Checked after filtering: a directory holding only unrelated files must not
    # reach max() on an empty accuracy list.
    if not snap_names:
        print('No snapshots found!')
        return

    # Initialize TreeGenerator and TreeTokenizer
    tree_generator = TreeGenerator(
        depth=depth,
        node_width=2,   # Binary trees
        batch_size=batch_size
    )
    tree_tokenizer = TreeTokenizer(context_len=context_len)
    accuracies = []

    for name in tqdm(snap_names, desc="Evaluating snapshots"):
        talker = snap.load(name, talker)
        # Generate test data
        tree_strings = tree_generator.generate_trees()
        data, _, _, _ = tree_tokenizer.encode(tree_strings)
        # Inference
        pred_tokens = inference(reasoner, talker, data, start, end)
        pred_list = tree_tokenizer.decode(np.asarray(pred_tokens))
        # Statistics: a sample counts as correct when the decoded prediction
        # equals the text after the last '[ROUTE]' marker.
        correct_items = sum(
            1
            for i, tree in enumerate(tree_strings)
            if tree.split('[ROUTE]')[-1] == pred_list[i]
        )
        accuracies.append(correct_items / batch_size)

    # Report by snapshot name instead of an index-derived step number, so every
    # accuracy is attributed to the snapshot that actually produced it.
    print(f'Evaluated {len(accuracies)} snapshots, accuracies:')
    for name, accuracy in zip(snap_names, accuracies):
        print(f'{name} - {accuracy}')
    best_accuracy = max(accuracies)
    best_name = snap_names[accuracies.index(best_accuracy)]
    print(f'Evaluation finished: Highest accuracy: {best_accuracy}, from snapshot {best_name}')


if __name__ == '__main__':
    # Script entry point: evaluate the depth-4 talker snapshots with 10240 trees
    # per snapshot, using a fixed reasoner snapshot.
    main(
        depth=4,
        batch_size=10240,
        reasoner_snapshot='snapshot/tree_search/reasoner/exp_sst_tree-search-d4/exp_tree-search_6000',
        talker_snapshot_base='snapshot/tree_search/talker/train',
    )
