import sys
import jax
import json
import numpy as np
import model
import jax.numpy as jnp
from flax import nnx
from exp_utils import CoconutDataLoaderMultiAr
from snapshot import Snapshot
import os
from util import sample_tokens, add_gaussian_noise_to_array


def _write_at_positions(latent, batch_indices, positions, values):
    """Return ``latent`` with ``latent[b, positions[b], :] = values[b]`` for every row ``b``.

    ``batch_indices`` and ``positions`` are both 1-D of length batch, so the two
    advanced indices pair up element-wise. (Indexing with ``positions[:, None]``
    would instead broadcast to a (batch, batch) grid and write every row's
    position into every other row as well.)
    """
    return latent.at[batch_indices, positions, :].set(values)


def _add_scaled_noise(array, key, noise_std):
    """Perturb ``array`` via ``add_gaussian_noise_to_array``, scaled by its largest absolute value.

    Returns the ``(array, key)`` pair produced by ``add_gaussian_noise_to_array``.
    """
    # max(|x|), not |max(x)|: the latter underestimates the scale whenever the
    # most extreme entries are negative.
    max_magnitude = jnp.max(jnp.abs(array))
    return add_gaussian_noise_to_array(array=array, key=key, noise_ratio=noise_std, magnitude=max_magnitude)


def inference_fn(instance, data, target, indices, soft_thinking_step, token_generation_step, noise_std, key=None):
    """Run soft thinking, then token generation, and return token accuracy.

    The input is embedded. Then ``soft_thinking_step`` iterations take the
    (noised) output hidden state at each row's last valid position and write it
    in as the next input embedding. Next, ``token_generation_step - 1``
    iterations sample a token from the last valid position, embed it, add noise,
    and write it in the same way. A final forward pass produces logits whose
    sampled tokens are compared with ``target`` over the last
    ``token_generation_step`` valid positions of each row.

    Args:
        instance: Model exposing ``embed``, ``latent_reasoning`` and ``assemble``.
        data: Model input passed to ``instance.embed``; presumably (batch, seq)
            token ids -- TODO confirm.
        target: Expected tokens, compared element-wise with the sampled tokens;
            presumably (batch, seq) -- TODO confirm.
        indices: Per-row number of filled positions, shape (batch,). The next
            write for row ``b`` goes to position ``indices[b]``.
        soft_thinking_step: Number of latent feedback iterations.
        token_generation_step: Number of positions scored per row.
        noise_std: Passed as ``noise_ratio`` to ``add_gaussian_noise_to_array``.
        key: PRNG key for the noise. Defaults to ``jax.random.PRNGKey(0)``,
            which makes every call use the same noise stream. Pass a distinct
            key per batch for independent noise.

    Returns:
        Fraction of scored positions where the sampled token equals ``target``.
    """
    input_latent = instance.embed(data)
    batch_indices = jnp.arange(indices.shape[0], dtype=jnp.int32)
    if key is None:
        key = jax.random.PRNGKey(0)

    def soft_thinking_loop_body(i, carry):
        latent, current_indices, loop_key = carry
        hidden = jax.lax.stop_gradient(instance.latent_reasoning(latent, q_len=current_indices, kv_len=current_indices))
        hidden, loop_key = _add_scaled_noise(hidden, loop_key, noise_std)
        # Feed the hidden state of the last valid position back as the next input.
        latent = _write_at_positions(latent, batch_indices, current_indices, hidden[batch_indices, current_indices - 1, :])
        return latent, current_indices + 1, loop_key

    final_latent, final_indices, final_key = jax.lax.fori_loop(
        0, soft_thinking_step, soft_thinking_loop_body, (input_latent, indices, key)
    )

    def token_generation_loop_body(i, carry):
        latent, current_indices, loop_key = carry
        logits = jax.lax.stop_gradient(
            instance.assemble(instance.latent_reasoning(latent, q_len=current_indices, kv_len=current_indices))
        )
        # Logits of each row's last valid position, kept as a length-1 sequence:
        # (batch_size, 1, vocab_size).
        last_valid_logits = logits[batch_indices, current_indices - 1][:, None, :]
        output_tokens = sample_tokens(last_valid_logits)
        token_embedding = jax.lax.stop_gradient(instance.embed(output_tokens))
        token_embedding, loop_key = _add_scaled_noise(token_embedding, loop_key, noise_std)
        # The embedding has a single sequence slot; write it at each row's next position.
        latent = _write_at_positions(latent, batch_indices, current_indices, token_embedding[:, 0, :])
        return latent, current_indices + 1, loop_key

    # The last token is only predicted (never fed back), hence ``- 1``.
    final_latent, final_indices, _ = jax.lax.fori_loop(
        0, token_generation_step - 1, token_generation_loop_body, (final_latent, final_indices, final_key)
    )

    logits = jax.lax.stop_gradient(instance.assemble(instance.latent_reasoning(final_latent, q_len=final_indices, kv_len=final_indices)))
    generated_tokens = sample_tokens(logits)
    seq_len = generated_tokens.shape[1]

    # Score only the last ``token_generation_step`` valid positions of each row.
    positions = jnp.arange(seq_len)[None, :]
    start_positions = (final_indices - token_generation_step)[:, None]
    end_positions = final_indices[:, None]
    mask = (positions >= start_positions) & (positions < end_positions)

    correct = jnp.sum((generated_tokens == target) & mask)
    total = jnp.sum(mask)
    return correct / total


def benchmark(instance, metrics, data, target, indices, soft_thinking_step, token_generation_step, noise_std):
    """Score one batch with ``inference_fn`` and record the result.

    The scalar accuracy returned by ``inference_fn`` is pushed into ``metrics``
    under the ``accuracy`` keyword. Nothing is returned.
    """
    batch_accuracy = inference_fn(
        instance,
        data,
        target,
        indices,
        soft_thinking_step,
        token_generation_step,
        noise_std,
    )
    metrics.update(accuracy=batch_accuracy)


def main(init_snap: str=None, size: str = 'large', soft_thinking_step: int=4, token_generation_step: int=4, noise_stds: list[float]=None, data_path: str=None):
    """Benchmark the Transformer at several noise levels and print accuracy per level.

    Args:
        init_snap: Snapshot path to load into the model. When falsy, the freshly
            initialised weights are evaluated.
        size: Selects the ``model_{size}`` section of
            ``./configs/transformer_cfg.json``.
        soft_thinking_step: Latent feedback iterations per sample.
        token_generation_step: Positions scored per sample.
        noise_stds: Noise ratios to sweep. ``None`` means a single noiseless
            run (``[0.0]``).
        data_path: Directory searched for ``cfg_test*.parquet``. ``None`` falls
            back to the module-level ``PATH_TO_CFG_DATA`` global.

    Terminates the process with ``sys.exit(0)`` when finished.
    """
    if noise_stds is None:
        noise_stds = [0.0]
    if data_path is None:
        data_path = PATH_TO_CFG_DATA

    # Model hyper-parameters for the requested size.
    with open('./configs/transformer_cfg.json', 'r', encoding='utf-8') as file:
        model_cfg = json.load(file)[f'model_{size}']
    feature = int(model_cfg['Feature'])
    attn_feature = int(model_cfg['ATTN Feature'])
    ffn_feature = int(model_cfg['FFN Feature'])
    num_head = int(model_cfg['Head Count'])
    decoder_count = int(model_cfg['Decoder Count'])
    init_scalar = float(model_cfg['Init Scalar'])
    max_len = int(model_cfg['Max Length'])
    rope_base = float(model_cfg['RoPE Base'])

    # Benchmark settings.
    with open('./configs/benchmark.json', 'r', encoding='utf-8') as file:
        bench_cfg = json.load(file)['coco-cfg']
    context_len = int(bench_cfg['Context Length'])
    total_steps = int(bench_cfg['Total Steps'])
    batch_size = int(bench_cfg['Batch Size'])

    # Build model instance
    key = jax.random.key(0)
    instance = model.Transformer(
        feature=feature,
        attn_feature=attn_feature,
        ffn_feature=ffn_feature,
        num_head=num_head,
        decoder_count=decoder_count,
        is_causal=True,
        init_scalar=init_scalar,
        vocab_size=32,
        key=key,
        dtype=jnp.bfloat16
    )
    instance.eval(
        rope_base=rope_base,
        max_len=max_len
    )

    if init_snap:
        snap = Snapshot(os.path.dirname(init_snap))
        instance = snap.load(os.path.basename(init_snap), instance)

    metrics = nnx.MultiMetric(accuracy=nnx.metrics.Average('accuracy'))

    # Data loader setup
    eval_loader = CoconutDataLoaderMultiAr(
        path=[data_path],
        pattern='cfg_test*.parquet',
        pad_token=0,
        mem=128,
        batch_size=batch_size,
        context_len=context_len,
        threads=2,
        dtype=np.int32,
        # Use this call's arguments, not module globals. This equals the number of
        # positions inference_fn fills: all soft-thinking steps plus the
        # token_generation_step - 1 fed-back tokens.
        soft_thinking_steps=soft_thinking_step + token_generation_step - 1
    )

    try:
        for noise in noise_stds:
            # Reset so each level reports its own average, not a running
            # average over every level evaluated so far.
            metrics.reset()
            for _ in range(total_steps):
                data, target, _mask, indices = next(eval_loader)
                benchmark(instance, metrics, data, target, indices, soft_thinking_step, token_generation_step, noise)

            # Inner quotes differ from the outer ones so this parses before Python 3.12.
            print(f"total accuracy at noise_level {noise}:{metrics.compute().get('accuracy')}")
    finally:
        # Stop the loader's worker threads even if evaluation raised.
        eval_loader.stop()
    sys.exit(0)


if __name__ == '__main__':
    # Set before main() runs any JAX computation. This presumably takes effect
    # because the GPU backend is initialised lazily on first use, even though
    # jax is imported at the top of the file -- confirm if backends change.
    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.99'

    if len(sys.argv) < 3:
        sys.exit(f'usage: {sys.argv[0]} <model_size> <snapshot_path>')

    # Kept as module-level globals because main() may read them from module scope.
    PATH_TO_CFG_DATA = './data'
    SOFT_THINK_STEP = 4
    TOKEN_GEN_STEP = 4
    model_size = sys.argv[1]
    model_snap = sys.argv[2]
    noise_std = [0.0, 0.05, 0.1, 0.15]
    main(
        init_snap=model_snap,
        size=model_size,
        soft_thinking_step=SOFT_THINK_STEP,
        token_generation_step=TOKEN_GEN_STEP,
        noise_stds=noise_std
    )