"""Analyze loss as a function of context depth for Llama 3.2 1B."""

import argparse
import pickle
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

from fma_llama.model.config import LlamaConfig
from fma_llama.model.llama import LlamaForCausalLM


def load_model(checkpoint_dir: str):
    """Restore a pickled Flax checkpoint and build the matching model.

    Two files are read from ``checkpoint_dir``: ``config.pkl`` (the model
    configuration) and ``flax_params.pkl`` (the parameters).

    Args:
        checkpoint_dir: Directory containing the two pickle files.

    Returns:
        Tuple ``(model, params, config)`` where ``model`` is a
        ``LlamaForCausalLM`` constructed from ``config``.
    """
    root = Path(checkpoint_dir)

    def _unpickle(filename):
        # NOTE: unpickling can execute arbitrary code; only use trusted checkpoints.
        with open(root / filename, 'rb') as handle:
            return pickle.load(handle)

    config = _unpickle("config.pkl")
    params = _unpickle("flax_params.pkl")

    return LlamaForCausalLM(config), params, config


def compute_per_position_loss(
    model,
    params,
    input_ids: jnp.ndarray,
    eos_token_id: int,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Compute per-position next-token cross-entropy loss with an EOS mask.

    Args:
        model: Flax model; ``model.apply(params, input_ids)`` must return logits
            of shape (batch_size, seq_len, vocab)
        params: Model parameters
        input_ids: Input token IDs of shape (batch_size, seq_len)
        eos_token_id: Token ID for end-of-sequence

    Returns:
        Tuple of (loss_per_position, mask) both of shape (batch_size, seq_len-1).
        Entry ``t`` of the loss is the negative log-probability assigned to
        token ``t+1`` given tokens ``0..t``. mask is True for targets up to and
        including the first EOS, False for everything after it.
    """
    logits = model.apply(params, input_ids)

    # Shift so that the logits at position t are scored against token t+1.
    shift_logits = logits[:, :-1, :]  # (batch, seq_len-1, vocab)
    shift_labels = input_ids[:, 1:]    # (batch, seq_len-1)

    # Exclusive running count: how many EOS targets occur strictly BEFORE each
    # position. Subtracting is_eos from the inclusive cumsum keeps the first
    # EOS itself inside the mask (an inclusive count would already be 1 there
    # and wrongly drop the EOS prediction).
    is_eos = (shift_labels == eos_token_id).astype(jnp.int32)
    eos_before = jnp.cumsum(is_eos, axis=1) - is_eos
    mask = eos_before == 0

    # Negative log-likelihood of the target token at every position.
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    loss_per_position = -jnp.take_along_axis(
        log_probs,
        shift_labels[:, :, None],
        axis=-1
    ).squeeze(-1)  # (batch, seq_len-1)

    return loss_per_position, mask


def segexpmeanlog(values, segment_ids, num_segments):
    """Return the per-segment geometric mean, computed as exp(mean(log(values))).

    A segment that receives no elements has a zero count, so its entry is the
    result of 0/0 (NaN) rather than a number.
    """
    log_total = jax.ops.segment_sum(jnp.log(values), segment_ids, num_segments)
    member_count = jax.ops.segment_sum(
        jnp.ones_like(values), segment_ids, num_segments
    )
    mean_log = log_total / member_count
    return jnp.exp(mean_log)


def fit_irreducible(x, y):
    """Estimate the irreducible loss as the offset of a fitted power law.

    Fits ``y ~ scale * x**exponent + offset`` with ``scipy.optimize.curve_fit``
    and returns ``offset``. If the fit raises, a warning is printed and
    ``min(y) - 0.01`` is returned instead.
    """
    def _offset_power_law(pos, scale, exponent, offset):
        return scale * (pos ** exponent) + offset

    try:
        fitted, _cov = curve_fit(
            _offset_power_law, x, y, p0=[1.0, -0.5, 0.0], maxfev=10000
        )
    except Exception as e:
        print(f"Warning: fit_irreducible failed with error {e}")
        return np.min(y) - 0.01

    # Third fitted parameter is the additive offset.
    return fitted[2]


def process_loc_loss(loc_loss, num_buckets=1):
    """Process location loss with exponential bucketing.

    Positions are 1-indexed and position ``p`` is assigned to bucket
    ``ceil(num_buckets * log2(p))``. Within each bucket both the loss and the
    position are summarised by their geometric mean; buckets whose loss is not
    finite (e.g. buckets that received no positions) are dropped. The
    irreducible loss is the offset of a power law fitted to the upper half of
    the remaining buckets, clamped so it is always at least 0.01 below the
    smallest bucketed loss (keeping the reducible loss strictly positive).

    Args:
        loc_loss: Loss per position of shape (T,)
        num_buckets: Number of buckets per octave (doubling of position)

    Returns:
        positions_buckets: Bucketed position values
        loc_loss_buckets: Bucketed loss values (with irreducible subtracted)
        irreducible_loss: Estimated irreducible loss

    Raises:
        ValueError: If ``loc_loss`` is empty or no bucket has a finite loss.
    """
    T, = loc_loss.shape
    if T == 0:
        raise ValueError("process_loc_loss requires at least one position")
    positions = jnp.arange(T) + 1  # 1-indexed positions

    # Create exponentially-spaced buckets
    seg_labels = jnp.ceil(num_buckets * jnp.log2(positions)).astype(jnp.int32)
    # The segment count must be a concrete Python int, not a JAX array.
    num_segs = int(jnp.max(seg_labels)) + 1

    # Bucket the losses using geometric mean
    loc_loss_buckets = np.asarray(segexpmeanlog(loc_loss, seg_labels, num_segs))
    positions_buckets = np.asarray(segexpmeanlog(positions, seg_labels, num_segs))

    # Remove non-finite values (boolean indexing returns fresh arrays)
    finite = np.isfinite(loc_loss_buckets)
    safe_pos_buck = positions_buckets[finite]
    safe_loc_loss_buck = loc_loss_buckets[finite]
    if safe_loc_loss_buck.size == 0:
        raise ValueError("process_loc_loss found no bucket with a finite loss")

    # Upper bound for the estimate, and the value used when no fit is possible.
    fallback = np.min(safe_loc_loss_buck) - 0.01

    # Fit the power law on the upper half of the buckets only. The fit helper
    # handles its own failures, so no extra exception handling is needed here.
    N = len(safe_pos_buck)
    if N > 6:
        fit_start = N // 2
        irreducible_loss = fit_irreducible(
            safe_pos_buck[fit_start:],
            safe_loc_loss_buck[fit_start:],
        )
    else:
        irreducible_loss = fallback

    # A NaN/inf fit result would slip through min() below and poison every
    # reducible value, so replace it with the fallback first.
    if not np.isfinite(irreducible_loss):
        irreducible_loss = fallback

    # Ensure irreducible is below minimum
    irreducible_loss = min(irreducible_loss, fallback)

    # Subtract irreducible loss to get reducible loss
    safe_loc_loss_buck = safe_loc_loss_buck - irreducible_loss

    return safe_pos_buck, safe_loc_loss_buck, irreducible_loss


def prepare_pg19_data(
    tokenizer,
    max_length: int = 2048,
    num_samples: int = 100,
):
    """Tokenize PG-19 test examples and cut them into fixed-length chunks.

    Args:
        tokenizer: Object whose ``encode(text, add_special_tokens=False)``
            returns a list of token IDs.
        max_length: Exact length of every returned chunk; a trailing chunk
            shorter than this is discarded.
        num_samples: Number of dataset examples to tokenize. Each example can
            contribute several chunks, so the returned list may be longer.

    Returns:
        List of token-ID lists, each of length ``max_length``.
    """
    print("Loading PG-19 dataset...")
    dataset = load_dataset('emozilla/pg19', split='test', trust_remote_code=True, num_proc=16)

    chunks = []
    for doc_index, record in enumerate(tqdm(dataset, desc="Tokenizing")):
        if doc_index >= num_samples:
            break

        token_ids = tokenizer.encode(record['text'], add_special_tokens=False)

        # Consecutive, non-overlapping windows; keep only the full-length ones.
        windows = (
            token_ids[start:start + max_length]
            for start in range(0, len(token_ids), max_length)
        )
        chunks.extend(window for window in windows if len(window) == max_length)

    return chunks


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the loss-by-position analysis."""
    parser = argparse.ArgumentParser(description='Analyze loss by position for Llama 3.2 1B')
    parser.add_argument(
        '--checkpoint_dir',
        type=str,
        default='checkpoints/llama-3.2-1b-flax',
        help='Checkpoint directory',
    )
    parser.add_argument(
        '--max_length',
        type=int,
        default=2**15,
        help='Maximum sequence length',
    )
    parser.add_argument(
        '--num_samples',
        type=int,
        default=400,
        help='Number of samples to evaluate',
    )
    parser.add_argument(
        '--output',
        type=str,
        default='loss_by_position.png',
        help='Output plot filename',
    )
    parser.add_argument(
        '--use_fma_attention',
        action='store_true',
        help='Enable FMA attention approximation',
    )
    parser.add_argument(
        '--fma_block_size',
        type=int,
        default=None,
        help='Block size for FMA approximation (default: use config value)',
    )
    parser.add_argument(
        '--fma_num_clusters',
        type=int,
        default=None,
        help='Number of clusters for FMA approximation (default: use config value)',
    )
    parser.add_argument(
        '--fma_num_retrievals',
        type=int,
        default=None,
        help='Number of retrievals for FMA approximation (default: use config value)',
    )
    parser.add_argument(
        '--fma_bidiagonal',
        action='store_true',
        help='Use bidiagonal approximation for FMA',
    )
    return parser


def _apply_fma_overrides(config, args) -> None:
    """Copy the FMA command-line options onto ``config``, printing each change."""
    # The flag always wins over the checkpoint: absent flag => standard attention.
    config.use_fma_attention = bool(args.use_fma_attention)
    if config.use_fma_attention:
        print("Enabling FMA attention")
    else:
        print("Using standard attention")

    # Integer options only override the checkpoint value when given explicitly.
    for option, label in (
        ('fma_block_size', 'FMA block size'),
        ('fma_num_clusters', 'FMA num clusters'),
        ('fma_num_retrievals', 'FMA num retrievals'),
    ):
        value = getattr(args, option)
        if value is not None:
            setattr(config, option, value)
            print(f"{label}: {getattr(config, option)}")

    if args.fma_bidiagonal:
        config.fma_bidiagonal = True
        print("FMA bidiagonal: enabled")


def _plot_reducible_loss(ax, avg_loss, num_buckets, title, markersize, annotate_last=0):
    """Plot bucketed reducible loss on log-log axes and return the irreducible loss.

    The last ``annotate_last`` points are labelled with their total loss
    (reducible + irreducible).
    """
    x, y, irreducible = process_loc_loss(avg_loss, num_buckets=num_buckets)
    ax.plot(x, y, 'o-', linewidth=2, markersize=markersize)

    for i in range(max(0, len(x) - annotate_last), len(x)):
        total_loss = y[i] + irreducible
        ax.annotate(f'{total_loss:.3f}',
                    xy=(x[i], y[i]),
                    xytext=(5, 5),
                    textcoords='offset points',
                    fontsize=8,
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7))

    ax.set_title(f"{title} (irreducible: {irreducible:.3f})")
    ax.set_xlabel('Token Position')
    ax.set_ylabel('Reducible Loss (nats)')
    ax.set_yscale('log')
    ax.set_xscale('log')
    ax.grid(True, alpha=0.3)
    return irreducible


def main():
    """Evaluate per-position loss on PG-19, plot it, and print a summary.

    Steps: parse arguments, load the checkpoint and apply FMA overrides, load
    the tokenizer, build fixed-length samples, average the masked loss at each
    position over all samples, then save a two-panel log-log plot (1 and 4
    buckets per octave) and print summary statistics.

    Raises:
        SystemExit: On invalid ``--max_length`` or when no full-length sample
            could be prepared.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.max_length < 2:
        # Next-token loss needs at least one context token and one target.
        parser.error('--max_length must be at least 2')

    print(f"Loading model from {args.checkpoint_dir}")
    model, params, config = load_model(args.checkpoint_dir)

    # Override FMA config from command-line arguments
    # NOTE(review): the model was already constructed from this config object
    # inside load_model; these overrides presumably rely on the model reading
    # the same object at apply time -- confirm.
    _apply_fma_overrides(config, args)

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B')
    eos_token_id = tokenizer.eos_token_id
    print(f"EOS token ID: {eos_token_id}")

    # Prepare data
    samples = prepare_pg19_data(
        tokenizer,
        max_length=args.max_length,
        num_samples=args.num_samples,
    )
    print(f"Prepared {len(samples)} samples of length {args.max_length}")
    if not samples:
        # Without this guard the script would average zero samples and emit a
        # meaningless plot and a 0/0 summary.
        raise SystemExit(
            "No full-length samples were prepared; "
            "lower --max_length or raise --num_samples"
        )

    # JIT compile for speed
    # NOTE(review): model/params are captured by closure, so jit treats them as
    # constants of the compiled function -- confirm this is acceptable for
    # compile time and memory.
    jit_compute_loss = jax.jit(lambda x: compute_per_position_loss(model, params, x, eos_token_id))

    # Accumulate per-position losses and masks across all samples
    print("Computing per-position losses...")
    accumulated_loss = jnp.zeros(args.max_length - 1)
    accumulated_mask = jnp.zeros(args.max_length - 1)

    for sample in tqdm(samples, desc="Processing samples"):
        loss_per_pos, mask = jit_compute_loss(jnp.array([sample]))
        loss_per_pos = loss_per_pos[0]  # Remove batch dim
        mask = mask[0]  # Remove batch dim

        # Select instead of multiply so a non-finite loss at a masked-out
        # position cannot leak into the sum (nan * 0 is still nan).
        accumulated_loss += jnp.where(mask, loss_per_pos, 0.0)
        accumulated_mask += mask

    # Average across samples using mask: sum(loss) / sum(mask)
    # The tiny lower bound avoids division by zero at never-unmasked positions.
    avg_loss_per_position = accumulated_loss / jnp.maximum(accumulated_mask, 1e-10)
    avg_loss_np = np.array(avg_loss_per_position)

    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # Add attention type to title
    attn_type = "FMA" if config.use_fma_attention else "Standard"

    # Plot 1: Coarse bucketing (1 bucket per octave), last 5 points labelled
    _plot_reducible_loss(
        ax1,
        avg_loss_np,
        num_buckets=1,
        title=f"{attn_type} - Loss by Context Depth - Coarse",
        markersize=6,
        annotate_last=5,
    )

    # Plot 2: Fine bucketing (4 buckets per octave)
    irred2 = _plot_reducible_loss(
        ax2,
        avg_loss_np,
        num_buckets=4,
        title=f"{attn_type} - Loss by Context Depth - Fine",
        markersize=4,
    )

    plt.tight_layout()
    plt.savefig(args.output, dpi=300, bbox_inches='tight')
    plt.close(fig)  # release the figure once it is on disk
    print(f"\nPlot saved to {args.output}")

    # Print summary statistics
    total_tokens = accumulated_mask.sum()
    total_possible = len(samples) * (args.max_length - 1)
    mask_ratio = total_tokens / total_possible

    print("\nSummary:")
    print(f"  Total tokens (after masking): {int(total_tokens)} / {total_possible} ({mask_ratio*100:.1f}%)")
    print(f"  Average loss: {float(jnp.mean(avg_loss_per_position)):.4f} nats")
    print(f"  Average perplexity: {float(jnp.exp(jnp.mean(avg_loss_per_position))):.4f}")
    print(f"  Irreducible loss: {irred2:.4f} nats")
    print(f"  First token loss: {float(avg_loss_per_position[0]):.4f} nats")
    print(f"  Last token loss: {float(avg_loss_per_position[-1]):.4f} nats")
    print(f"  Context length: {args.max_length}")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
