import jax
import json
import numpy as np
import model
import jax.numpy as jnp
from flax import nnx
from exp_utils import BenchmarkDataLoader
from snapshot import Snapshot
import os
import sys
from util import sample_tokens, add_noise_to_data
from datetime import datetime


def inference_fn(instance, data, target, indices, inference_step):
    """Autoregressively extend each row by ``inference_step`` tokens and score them.

    Starting at the per-row position ``indices[b]``, every iteration runs the model on
    the current token buffer, takes the logits at the last filled slot (``pos - 1``),
    samples one token with ``sample_tokens`` and writes it at slot ``pos``. After the
    loop, the ``inference_step`` freshly written slots of each row are compared with
    ``target`` at the same positions.

    Args:
        instance: callable model, invoked as ``instance(tokens, q_len=pos, kv_len=pos)``.
        data: token buffer indexed as ``[row, position]``.
        target: reference tokens compared position-wise with the generated buffer.
        indices: 1-D per-row index of the first slot to fill.
        inference_step: number of tokens generated per row.

    Returns:
        Scalar fraction of generated slots that match ``target`` (0/0 -> NaN when
        nothing is generated).
    """
    rows = jnp.arange(indices.shape[0], dtype=jnp.int32)

    def step(_, state):
        tokens, pos = state
        logits = jax.lax.stop_gradient(instance(tokens, q_len=pos, kv_len=pos))
        # Logits at the last filled slot predict the token for slot ``pos``; the
        # [:, None, :] restores a length-1 sequence axis for sample_tokens.
        next_logits = logits[rows, pos - 1][:, None, :]
        sampled = jnp.squeeze(sample_tokens(next_logits), axis=1)
        return tokens.at[rows, pos].set(sampled), pos + 1

    generated, end = jax.lax.fori_loop(0, inference_step, step, (data, indices))

    # Per-row window [end - inference_step, end) covering exactly the generated slots.
    cols = jnp.arange(generated.shape[1])[None, :]
    window = (cols >= (end - inference_step)[:, None]) & (cols < end[:, None])

    hits = jnp.sum((generated == target) & window)
    return hits / jnp.sum(window)


def benchmark(instance, metrics, data, target, indices, inference_step, noise_fraction):
    """Noise one batch, run generation on it, and fold the accuracy into ``metrics``.

    ``add_noise_to_data`` is called with the token list ``[14, 15, 16]`` and a trailing
    ``0``. NOTE(review): presumably these are the candidate noise token ids and a seed or
    offset; confirm against ``util.add_noise_to_data``.
    """
    noisy_data = add_noise_to_data(data, indices, [14, 15, 16], noise_fraction, 0)
    metrics.update(accuracy=inference_fn(instance, noisy_data, target, indices, inference_step))


def run_single_benchmark(instance, eval_loader, total_steps, inference_step, noise_fraction):
    """Average the accuracy over ``total_steps`` loader batches for one configuration.

    Each batch drawn from ``eval_loader`` is a ``(data, target, mask, indices)`` tuple.
    The mask is not used here. Returns the ``'accuracy'`` entry of the averaged metrics.
    """
    tracker = nnx.MultiMetric(accuracy=nnx.metrics.Average('accuracy'))

    for _ in range(total_steps):
        batch_data, batch_target, _unused_mask, batch_indices = next(eval_loader)
        benchmark(instance, tracker, batch_data, batch_target, batch_indices,
                  inference_step, noise_fraction)

    summary = tracker.compute()
    return summary.get('accuracy')


def _load_model_config(size):
    """Return the ``model_{size}`` hyperparameters from transformer_cfg.json, cast to int/float."""
    with open('./configs/transformer_cfg.json', 'r') as file:
        model_cfg = json.load(file)[f'model_{size}']
    return {
        'feature': int(model_cfg['Feature']),
        'attn_feature': int(model_cfg['ATTN Feature']),
        'ffn_feature': int(model_cfg['FFN Feature']),
        'num_head': int(model_cfg['Head Count']),
        'decoder_count': int(model_cfg['Decoder Count']),
        'init_scalar': float(model_cfg['Init Scalar']),
        'max_len': int(model_cfg['Max Length']),
        'rope_base': float(model_cfg['RoPE Base']),
    }


def _load_benchmark_config():
    """Return ``(total_steps, context_len, batch_size)`` from benchmark.json's 'transformer-cfg'."""
    with open('./configs/benchmark.json', 'r') as file:
        bench_cfg = json.load(file)['transformer-cfg']
    return (
        int(bench_cfg['Total Steps']),
        int(bench_cfg['Context Length']),
        int(bench_cfg['Batch Size']),
    )


def _emit(line, output_lines):
    """Print ``line`` and record it in ``output_lines`` for the results file."""
    print(line)
    output_lines.append(line)


def main(init_snap: str = None, size: str = 'large', noise_fractions=None, inference_steps=None,
         data_path: str = None):
    """Sweep (generation length x noise fraction) for one transformer and report accuracy.

    Builds the ``model_{size}`` transformer, optionally restoring weights from
    ``init_snap``. Every (inference_step, noise_fraction) pair is evaluated with
    ``run_single_benchmark``. A results table is printed, and the whole console log is
    written to ``./results/transformer_{size}_results_<timestamp>.txt``.

    Args:
        init_snap: snapshot path to load; falsy keeps the freshly initialised weights.
        size: selects the ``model_{size}`` section of ./configs/transformer_cfg.json.
        noise_fractions: noise fractions to sweep; defaults to 0.0 .. 0.30 in 0.05 steps.
        inference_steps: generation lengths to sweep; defaults to 1 .. 4.
        data_path: directory holding ``cfg_test*.parquet``. Defaults to the module-level
            ``PATH_TO_CFG_DATA`` when that global exists, else ``'./data'``.

    A configuration that raises is reported and recorded with accuracy NaN (not 0.0),
    so failures are distinguishable from real 0% accuracy. Calls ``sys.exit(0)`` on
    successful completion.
    """
    if noise_fractions is None:
        noise_fractions = [0.0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30]
    if inference_steps is None:
        inference_steps = [1, 2, 3, 4]
    if data_path is None:
        # PATH_TO_CFG_DATA is only assigned under `if __name__ == '__main__'`; fall back
        # so main() also works when this module is imported.
        data_path = globals().get('PATH_TO_CFG_DATA', './data')

    hp = _load_model_config(size)
    total_steps, context_len, batch_size = _load_benchmark_config()

    # Build model instance
    key = jax.random.key(0)
    instance = model.Transformer(
        feature=hp['feature'],
        attn_feature=hp['attn_feature'],
        ffn_feature=hp['ffn_feature'],
        num_head=hp['num_head'],
        decoder_count=hp['decoder_count'],
        is_causal=True,
        init_scalar=hp['init_scalar'],
        vocab_size=32,
        key=key,
        dtype=jnp.bfloat16,
    )
    instance.eval(
        rope_base=hp['rope_base'],
        max_len=hp['max_len'],
    )

    if init_snap:
        snap = Snapshot(os.path.dirname(init_snap))
        instance = snap.load(os.path.basename(init_snap), instance)

    eval_loader = BenchmarkDataLoader(
        path=[data_path],
        pattern='cfg_test*.parquet',
        pad_token=0,
        mem=2048,
        batch_size=batch_size,
        context_len=context_len,
        threads=1,
        dtype=np.int32,
        inference_steps=1,
    )

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_filename = f"./results/transformer_{size}_results_{timestamp}.txt"

    output_lines = []
    results = []

    # Everything after the loader exists runs under try/finally so its threads are
    # always stopped, even if the sweep or reporting raises.
    try:
        header_info = [
            "Running benchmark sweep...",
            "=" * 60,
            f"Model Size: {size}",
            f"Snapshot: {init_snap if init_snap else 'None'}",
            "=" * 60,
        ]
        for line in header_info:
            _emit(line, output_lines)

        for inference_step in inference_steps:
            eval_loader.set_inference_steps(inference_step)

            for noise_fraction in noise_fractions:
                _emit(f"Running: Noise={noise_fraction}, Steps={inference_step}", output_lines)

                try:
                    accuracy = float(run_single_benchmark(
                        instance, eval_loader, total_steps, inference_step, noise_fraction))
                except Exception as e:
                    # Keep sweeping: one failing configuration should not abort the run.
                    _emit(f"  Error occurred: {e}", output_lines)
                    accuracy = float('nan')
                else:
                    _emit(f"  Accuracy: {accuracy:.4f}", output_lines)

                results.append({
                    'noise_fraction': noise_fraction,
                    'inference_steps': inference_step,
                    'accuracy': accuracy,
                })

        results_header = [
            "\n" + "=" * 60,
            f"BENCHMARK RESULTS FOR MODEL T_{size}",
            "=" * 60,
            f"{'Noise Fraction':<15} {'Gen Steps':<10} {'Accuracy':<10}",
            "-" * 35,
        ]
        for line in results_header:
            _emit(line, output_lines)

        for result in results:
            _emit(
                f"{result['noise_fraction']:<15.2f} {result['inference_steps']:<10} "
                f"{result['accuracy']:<10.4f}",
                output_lines,
            )

        try:
            # The results directory may not exist on a fresh checkout.
            os.makedirs(os.path.dirname(output_filename), exist_ok=True)
            with open(output_filename, 'w', encoding='utf-8') as f:
                f.write('\n'.join(output_lines))
        except (OSError, ValueError) as e:
            print(f"\nError saving results to file: {e}")
        else:
            print(f"\nResults saved to: {output_filename}")
    finally:
        eval_loader.stop()

    sys.exit(0)


if __name__ == '__main__':
    # main() looks up this module-level name to locate the CFG parquet data.
    PATH_TO_CFG_DATA = './data'

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} MODEL_SIZE [SNAPSHOT_PATH]")
    model_size = sys.argv[1]
    # The snapshot is optional: main() keeps fresh weights when init_snap is falsy.
    transformer_snap = sys.argv[2] if len(sys.argv) > 2 else None

    # Must be set before JAX initialises its GPU backend (on first computation);
    # importing jax alone does not do that, so setting it here still takes effect.
    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.99'

    noise_fractions = [0.0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30]
    inference_steps = [1, 2, 3, 4]
    main(transformer_snap, size=model_size, noise_fractions=noise_fractions, inference_steps=inference_steps)