import jax
import json
import numpy as np
import jax.numpy as jnp
from flax import nnx
from exp_utils import BenchmarkDataLoader
import os
import sys
from util import sample_tokens, load_models, add_gaussian_noise_to_array
from datetime import datetime


def inference_fn(reasoner, talker, data, target, indices, inference_step, noise_std):
    """Generate latents and tokens autoregressively and score them against ``target``.

    Two ``jax.lax.fori_loop`` phases run for ``inference_step`` iterations each:

    1. Reasoner phase: starting from ``reasoner.embed(data)``, every iteration
       calls ``reasoner.reason``, takes the latent at position
       ``current_indices - 1`` of each row, perturbs it with
       ``add_gaussian_noise_to_array`` and writes it at position
       ``current_indices``.
    2. Talker phase: starting from ``data``, every iteration calls ``talker`` on
       the generated latents, samples a token from the logits at position
       ``current_indices - 1`` of each row and writes it at position
       ``current_indices``.

    Args:
        reasoner: Object exposing ``embed(data)`` and
            ``reason(latent, q_len=..., kv_len=...)``.
        talker: Callable invoked as ``talker(latents, input_tokens=...,
            token_len=..., latent_len=...)``.
        data: Input tokens, converted with ``jnp.asarray`` and indexed as
            ``[row, position]``.
        target: Reference tokens, compared element-wise with the generated tokens.
        indices: Per-row write position; one entry per row. Each loop reads
            position ``indices - 1`` and writes position ``indices``.
        inference_step: Number of iterations of each loop.
        noise_std: Forwarded as ``noise_ratio`` to ``add_gaussian_noise_to_array``.

    Returns:
        ``correct / total`` over the scored positions, as a JAX scalar. If no
        position falls inside the scoring window (for example when
        ``inference_step <= 4``), ``total`` is 0 and the division is 0/0.
    """
    batch_indices = jnp.arange(indices.shape[0], dtype=jnp.int32)
    data = jnp.asarray(data)
    target = jnp.asarray(target)
    init_latent = reasoner.embed(data)
    # Fixed seed: every call starts the noise key chain from the same key.
    init_key = jax.random.PRNGKey(0)

    # First loop: Reasoner inference
    def reasoner_body(i, state):
        input_latent, current_indices, key = state

        output_latent = jax.lax.stop_gradient(
            reasoner.reason(input_latent, q_len=current_indices, kv_len=current_indices)
        )

        # Largest absolute value in the output. Taking abs() *before* max() is
        # what makes this a magnitude: abs(max(x)) would ignore a dominant
        # negative entry and underestimate the scale passed to the noise helper.
        max_magnitude = jnp.max(jnp.abs(output_latent))

        last_valid_latents = output_latent[batch_indices, current_indices - 1]
        last_valid_latents, key = add_gaussian_noise_to_array(
            array=last_valid_latents,
            noise_ratio=noise_std,
            key=key,
            magnitude=max_magnitude,
        )
        new_input_latent = input_latent.at[batch_indices, current_indices].set(last_valid_latents)
        new_current_indices = current_indices + 1

        return new_input_latent, new_current_indices, key

    generated_latents, reasoner_final_indices, _ = jax.lax.fori_loop(
        0, inference_step, reasoner_body, (init_latent, indices, init_key)
    )

    # Second loop: Talker inference
    def talker_body(i, state):
        input_token, current_indices = state

        logits = jax.lax.stop_gradient(
            talker(generated_latents,
                   input_tokens=input_token,
                   token_len=current_indices,
                   latent_len=reasoner_final_indices)
        )

        # Keep a singleton axis at position 1 for sample_tokens, then drop it
        # again from the sampled result.
        last_valid_logits = logits[batch_indices, current_indices - 1]
        last_valid_logits = last_valid_logits[:, None, :]

        output_tokens = sample_tokens(last_valid_logits)
        output_tokens_squeezed = jnp.squeeze(output_tokens, axis=1)

        new_input_token = input_token.at[batch_indices, current_indices].set(output_tokens_squeezed)
        new_current_indices = current_indices + 1

        return (new_input_token, new_current_indices)

    generated_tokens, final_indices = jax.lax.fori_loop(
        0, inference_step, talker_body, (data, indices)
    )

    # Accuracy calculation
    seq_len = generated_tokens.shape[1]
    positions = jnp.arange(seq_len)[None, :]
    # final_indices == indices + inference_step, so the window starts at
    # indices + 4: the first four generated positions are not scored.
    # NOTE(review): the reason for the offset of 4 is not visible here - confirm.
    start_positions = (final_indices - inference_step + 4)[:, None]
    end_positions = final_indices[:, None]

    mask = (positions >= start_positions) & (positions < end_positions)

    # Positions outside the window are removed by the mask, so the tokens can
    # be compared directly without substituting a sentinel first.
    correct = jnp.sum((generated_tokens == target) & mask)
    total = jnp.sum(mask)

    return correct / total


def benchmark(reasoner, talker, metrics, data, target, indices, inference_step, noise_std):
    """Score one batch with ``inference_fn`` and record it in ``metrics``.

    The value returned by ``inference_fn`` is passed to ``metrics.update`` under
    the ``accuracy`` keyword. Nothing is returned.
    """
    batch_accuracy = inference_fn(
        reasoner,
        talker,
        data,
        target,
        indices,
        inference_step,
        noise_std,
    )
    metrics.update(accuracy=batch_accuracy)


def run_single_benchmark(reasoner, talker, eval_loader, total_steps, inference_step, noise_std):
    """Evaluate one (inference_step, noise_std) configuration.

    Pulls ``total_steps`` batches from ``eval_loader`` with ``next``, scores each
    one through ``benchmark`` and returns the ``'accuracy'`` entry of the
    computed ``nnx.MultiMetric``.
    """
    tracker = nnx.MultiMetric(accuracy=nnx.metrics.Average('accuracy'))

    for _ in range(total_steps):
        # Each batch is a 4-tuple; its third element (a mask) is not used here.
        data, target, _unused_mask, indices = next(eval_loader)
        benchmark(reasoner, talker, tracker, data, target, indices, inference_step, noise_std)

    summary = tracker.compute()
    return summary.get('accuracy')


def main(reasoner_snap: str = None, talker_snap: str = None, size: str = 'large', inference_steps=None,
         noise_stds=None):
    """Run the benchmark sweep over ``inference_steps`` x ``noise_stds`` and save a report.

    Reads ``Context Length``, ``Total Steps`` and ``Batch Size`` from the
    ``reasoner-cfg`` section of ``./configs/benchmark.json``, loads both models
    once, evaluates every combination with ``run_single_benchmark``, prints the
    progress and a results table, and writes the same text to
    ``./results/benchmark_<size>_<timestamp>.txt``.

    Args:
        reasoner_snap: Reasoner snapshot, forwarded to ``load_models``.
        talker_snap: Talker snapshot, forwarded to ``load_models``.
        size: Model size, forwarded to ``load_models`` and used in the report name.
        inference_steps: Iterable of step counts to sweep. Required in practice:
            iterating the ``None`` default raises ``TypeError``.
        noise_stds: Iterable of noise levels to sweep. Required in practice for
            the same reason.

    Notes:
        - The data path comes from the module-level global ``PATH_TO_CFG_DATA``,
          which this file only defines in its ``__main__`` block.
        - A configuration that raises is reported and recorded with accuracy 0.0.
        - The function ends with ``sys.exit(0)`` and therefore does not return.
        - NOTE(review): one loader is shared by all configurations, so each
          configuration appears to consume the *next* ``total_steps`` batches
          rather than the same ones - confirm against ``BenchmarkDataLoader``.
    """
    with open('./configs/benchmark.json', 'r') as file:
        config = json.load(file)
    reasoner_cfg = config['reasoner-cfg']
    context_len = int(reasoner_cfg['Context Length'])
    total_steps = int(reasoner_cfg['Total Steps'])
    batch_size = int(reasoner_cfg['Batch Size'])

    # Load models once at the beginning
    reasoner, talker = load_models(reasoner_snap, talker_snap, size=size)

    results = []
    eval_loader = BenchmarkDataLoader(
        path=[PATH_TO_CFG_DATA],
        pattern='cfg_test*.parquet',
        pad_token=0,
        mem=128,
        batch_size=batch_size,
        context_len=context_len,
        threads=1,
        dtype=np.int32,
        inference_steps=1
    )

    # Create output filename with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_filename = f"./results/benchmark_{size}_{timestamp}.txt"

    output_lines = []

    def emit(line):
        # Echo to the console and keep a copy for the report file.
        print(line)
        output_lines.append(line)

    # From here on the loader is stopped even if the sweep aborts.
    try:
        for line in (
            "Running JEPA-Reasoner benchmark sweep...",
            "=" * 70,
            f"Reasoner: {reasoner_snap}",
            f"Talker: {talker_snap}",
            f"Model Size: {size}",
            "=" * 70,
        ):
            emit(line)

        # Nested loop for parameter sweep
        for inference_step in inference_steps:
            eval_loader.set_inference_steps(inference_step)

            for noise_std in noise_stds:
                emit(f"\nRunning: Noise={noise_std}, Steps={inference_step}")
                emit("-" * 50)

                try:
                    accuracy = run_single_benchmark(reasoner, talker, eval_loader, total_steps, inference_step,
                                                    noise_std)
                except Exception as e:
                    # Best effort: report the failure and continue with the
                    # remaining configurations. The failed run is recorded as
                    # 0.0, which the table cannot tell apart from a real 0%.
                    emit(f"  Error occurred: {e}")
                    results.append({
                        'noise_fraction': noise_std,
                        'inference_steps': inference_step,
                        'accuracy': 0.0
                    })
                else:
                    results.append({
                        'noise_fraction': noise_std,
                        'inference_steps': inference_step,
                        'accuracy': accuracy
                    })
                    emit(f"  Final Accuracy: {accuracy:.4f}")

        # Format results table
        for line in (
            "\n" + "=" * 60,
            f"BENCHMARK RESULTS FOR MODEL R_{size}",
            "=" * 60,
            f"{'Noise Fraction':<15} {'Gen Steps':<10} {'Accuracy':<10}",
            "-" * 35,
        ):
            emit(line)

        for result in results:
            emit(f"{result['noise_fraction']:<15.2f} {result['inference_steps']:<10} {result['accuracy']:<10.4f}")

        # Write all output to file
        try:
            # Create ./results on demand so a finished sweep is not lost to a
            # missing directory.
            os.makedirs(os.path.dirname(output_filename), exist_ok=True)
            with open(output_filename, 'w') as f:
                f.write('\n'.join(output_lines))

            print(f"\nResults saved to: {output_filename}")

        except Exception as e:
            print(f"\nError saving results to file: {e}")
    finally:
        eval_loader.stop()

    sys.exit(0)


if __name__ == '__main__':
    # Command line: <model_size> <reasoner_snapshot> <talker_snapshot>.
    # Fail with a usage message instead of a bare IndexError when one is missing.
    if len(sys.argv) < 4:
        sys.exit(f"Usage: {sys.argv[0]} <model_size> <reasoner_snapshot> <talker_snapshot>")

    # Fraction of device memory the XLA client may preallocate.
    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.99'
    # Module-level global read by main() when it builds the data loader.
    PATH_TO_CFG_DATA = './data'
    model_size = sys.argv[1]
    reasoner_snap = sys.argv[2]
    talker_snap = sys.argv[3]

    # Sweep grid: every noise level is evaluated for every step count.
    noise_std = [0.0, 0.05, 0.1, 0.15]
    inference_steps = [8]

    main(
        reasoner_snap=reasoner_snap,
        talker_snap=talker_snap,
        size=model_size,
        noise_stds=noise_std,
        inference_steps=inference_steps,
    )