"""Compare different FMA attention configurations by analyzing loss vs position."""

import argparse
import pickle
from pathlib import Path
from typing import Dict, List

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental import mesh_utils
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from functools import partial

from fma_llama.model.config import LlamaConfig
from fma_llama.model.llama import LlamaForCausalLM

mesh = Mesh(mesh_utils.create_device_mesh((1, jax.device_count())), ('data', 'model'))
jax.set_mesh(mesh)


# FMA configurations to compare. Every entry carries the same keys, which
# evaluate_config reads:
#   name           -- label used in progress output, the plot legend and the summary
#   use_fma        -- False selects standard attention (remaining FMA fields unused)
#   block_size     -- FMA block size (2**13 == 8192, the "8k" in the names)
#   num_clusters   -- number of clusters (the "128" in the names)
#   num_retrievals -- retrieval count; None is written as "0" in the names
#   bidiagonal     -- bidiagonal flag
#   dipole         -- dipole flag (the "-Dip" suffix)
FMA_CONFIGS = [
    {
        'name': 'Standard Attention',
        'use_fma': False,
        'block_size': None,
        'num_clusters': None,
        'num_retrievals': None,
        'bidiagonal': False,
        'dipole': False,
    },
    {
        'name': 'FMA-8k-128-0',
        'use_fma': True,
        'block_size': 2**13,
        'num_clusters': 128,
        'num_retrievals': None,
        'bidiagonal': False,
        'dipole': False,
    },
    {
        'name': 'FMA-8k-128-0-Dip',
        'use_fma': True,
        'block_size': 2**13,
        'num_clusters': 128,
        'num_retrievals': None,
        'bidiagonal': False,
        'dipole': True,
    },
    {
        'name': 'FMA-8k-128-8',
        'use_fma': True,
        'block_size': 2**13,
        'num_clusters': 128,
        'num_retrievals': 8,
        'bidiagonal': False,
        'dipole': False,
    },
]


def shard_params(params, mesh):
    """Shard loaded parameters across ``mesh`` using key-path based rules.

    Rules are substring matches against the '/'-joined key path of each leaf:
      * ``gate_proj`` / ``up_proj`` kernels -> ``P(None, 'model')``
      * ``down_proj`` kernels                 -> ``P('model', None)``
      * ``Dense`` kernels (presumably the Q/K/V/O projections -- TODO confirm)
                                              -> ``P(None, 'model')``
      * ``Embed`` / ``embedding`` leaves      -> ``P(None, 'model')``
      * anything else                         -> fully replicated

    Both ``jax.Array`` and NumPy array leaves are placed on the mesh (a pickled
    checkpoint may hold either). Non-array leaves are returned unchanged, and
    the container structure (dict, FrozenDict, list, ...) is preserved.

    Args:
        params: Loaded parameter tree
        mesh: Device mesh

    Returns:
        Sharded parameters
    """

    def _key_name(entry):
        # Key-path entries wrap the raw key (DictKey.key, GetAttrKey.name,
        # SequenceKey.idx); unwrap so rules match the bare key names.
        for attr in ('key', 'name', 'idx'):
            if hasattr(entry, attr):
                return str(getattr(entry, attr))
        return str(entry)

    def get_sharding_for_param(path_str):
        """Return the PartitionSpec for the leaf at ``path_str``."""
        # MLP gate/up projections: shard output dim
        if 'gate_proj' in path_str or 'up_proj' in path_str:
            if 'kernel' in path_str:
                return P(None, 'model')
        # MLP down projection: shard input dim
        elif 'down_proj' in path_str:
            if 'kernel' in path_str:
                return P('model', None)
        # Attention projections: shard on heads/hidden dim
        elif 'Dense' in path_str:
            if 'kernel' in path_str:
                return P(None, 'model')
        # Embeddings: shard hidden dim
        elif 'Embed' in path_str or 'embedding' in path_str:
            return P(None, 'model')

        # Default: replicate. P() is valid for any rank, including 0-d
        # scalars (P(None,) is rejected for those).
        return P()

    def _place(path, leaf):
        if not isinstance(leaf, (jax.Array, np.ndarray)):
            return leaf
        path_str = '/'.join(_key_name(entry) for entry in path)
        sharding = NamedSharding(mesh, get_sharding_for_param(path_str))
        return jax.device_put(leaf, sharding)

    return jax.tree_util.tree_map_with_path(_place, params)


def load_model(checkpoint_dir: str, mesh=None):
    """Build the Flax model from a checkpoint directory.

    The directory must contain ``config.pkl`` (the model config) and
    ``flax_params.pkl`` (the parameter tree).

    Args:
        checkpoint_dir: Path to checkpoint
        mesh: Optional device mesh; when given, params are placed on it via
            ``shard_params``.

    Returns:
        ``(model, params, config)``.
    """
    ckpt = Path(checkpoint_dir)

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file -- only load checkpoints from a trusted source.
    with (ckpt / "config.pkl").open('rb') as fh:
        config = pickle.load(fh)
    with (ckpt / "flax_params.pkl").open('rb') as fh:
        params = pickle.load(fh)

    model = LlamaForCausalLM(config)

    if mesh is None:
        return model, params, config

    print("Sharding parameters across devices...")
    return model, shard_params(params, mesh), config


def compute_per_position_loss(
    model,
    params,
    input_ids: jnp.ndarray,
    eos_token_id: int,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Compute per-position next-token cross-entropy loss with an EOS mask.

    Output position ``t`` is the loss of predicting ``input_ids[:, t + 1]`` from
    the logits at position ``t``.

    Args:
        model: Flax model; ``model.apply(params, input_ids)`` must return logits
            of shape (batch_size, seq_len, vocab_size).
        params: Model parameters
        input_ids: Input token IDs of shape (batch_size, seq_len)
        eos_token_id: Token ID for end-of-sequence

    Returns:
        Tuple of (loss_per_position, mask), both of shape (batch_size, seq_len-1).
        ``mask`` is True for every target up to and including the first EOS
        target, and False for all targets after it.

    Note:
        The logits are constrained with a bare PartitionSpec, so this must run
        with a device mesh in context.
    """
    logits = model.apply(params, input_ids)
    logits = jax.lax.with_sharding_constraint(logits, P('data', 'model', None))

    # Next-token targets, padded with a dummy 0 at the end so they line up with
    # the full-length logits (the logits themselves are never sliced). The
    # dummy last position is dropped before returning.
    padding = jnp.zeros((input_ids.shape[0], 1), dtype=input_ids.dtype)
    shift_labels = jnp.concatenate((input_ids[:, 1:], padding), axis=1)

    # Keep a target while no EOS target occurs strictly *before* it, so the
    # first EOS target itself is still scored. Subtracting `is_eos` turns the
    # inclusive cumsum into a count of earlier EOS targets.
    is_eos = shift_labels == eos_token_id
    eos_before = jnp.cumsum(is_eos, axis=1) - is_eos
    mask = eos_before == 0

    # -log_softmax(logits)[label] == logsumexp(logits) - logits[label]; this
    # form avoids explicitly building a second (batch, seq_len, vocab)
    # log-probability tensor.
    target_logits = jnp.take_along_axis(
        logits, shift_labels[:, :, None], axis=-1
    ).squeeze(-1)
    loss_per_position = jax.nn.logsumexp(logits, axis=-1) - target_logits

    # Drop the dummy padded target -> (batch, seq_len-1)
    return loss_per_position[:, :-1], mask[:, :-1]


def segexpmeanlog(values, segment_ids, num_segments):
    """Return the geometric mean of ``values`` within each segment.

    Computed as ``exp(mean(log(values)))`` per segment id. A segment with no
    members yields NaN (0 / 0).
    """
    log_totals = jax.ops.segment_sum(jnp.log(values), segment_ids, num_segments)
    member_counts = jax.ops.segment_sum(
        jnp.ones_like(values), segment_ids, num_segments
    )
    mean_logs = log_totals / member_counts
    return jnp.exp(mean_logs)


def fit_irreducible(x, y):
    """Estimate the irreducible loss as the asymptote of a power-law fit.

    Fits ``y ~= a * x**b + c`` with ``scipy.optimize.curve_fit`` (starting from
    a=1, b=-0.5, c=0) and returns the offset ``c``. If fitting raises, a warning
    is printed and ``min(y) - 0.01`` is returned instead.
    """
    def power_law(t, a, b, c):
        return a * (t ** b) + c

    initial_guess = [1.0, -0.5, 0.0]
    try:
        (_, _, offset), _ = curve_fit(
            power_law, x, y, p0=initial_guess, maxfev=10000
        )
    except Exception as e:
        print(f"Warning: fit_irreducible failed with error {e}")
        return np.min(y) - 0.01
    return offset

def cumexpmeanlog(values):
    """Return the running geometric mean of ``values``.

    Element ``t`` of the result is the geometric mean of ``values[:t + 1]``.
    """
    seen = jnp.arange(1, len(values) + 1)
    return jnp.exp(jnp.cumsum(jnp.log(values)) / seen)


def process_loc_loss(loc_loss, num_buckets=1, fixed_irreducible=None, cumulative=True):
    """Turn a per-position loss curve into plot-ready (position, reducible loss) points.

    Args:
        loc_loss: Loss per position of shape (T,); index 0 is position 1.
        num_buckets: Number of buckets per octave (per doubling of position).
            Only used when ``cumulative`` is False.
        fixed_irreducible: If provided, use this as irreducible loss instead of
            estimating it.
        cumulative: If True (the default, and the behaviour this script has been
            plotting), return one point per position whose value is the
            cumulative geometric mean of the loss over positions 1..t. If False,
            group positions into exponentially spaced buckets and return one
            point per non-empty bucket (geometric means of positions and losses).

    Returns:
        positions_buckets: Position values (NumPy), non-finite points removed
        loc_loss_buckets: Loss values with the irreducible loss subtracted
        irreducible_loss: Estimated or provided irreducible loss

    When estimated, the irreducible loss is the offset of a power-law fit
    (``fit_irreducible``) to the second half of the curve without its last
    eighth, clamped to at most ``min(loss) - 0.01`` so every reducible value is
    positive (the plot uses a log y-axis).
    """
    T, = loc_loss.shape
    positions = jnp.arange(T) + 1  # 1-indexed positions

    if cumulative:
        loc_loss_buckets = cumexpmeanlog(loc_loss)
        positions_buckets = positions
    else:
        # Bucket label ceil(num_buckets * log2(pos)): num_buckets per octave.
        seg_labels = jnp.ceil(num_buckets * jnp.log2(positions)).astype(jnp.int32)
        num_segs = int(jnp.max(seg_labels)) + 1
        # Empty buckets come out as NaN (0/0) and are dropped below.
        loc_loss_buckets = segexpmeanlog(loc_loss, seg_labels, num_segs)
        positions_buckets = segexpmeanlog(positions, seg_labels, num_segs)

    # Remove non-finite values
    finite = jnp.isfinite(loc_loss_buckets)
    safe_pos_buck = np.array(positions_buckets[finite])
    safe_loc_loss_buck = np.array(loc_loss_buckets[finite])

    if fixed_irreducible is not None:
        irreducible_loss = fixed_irreducible
    else:
        floor = np.min(safe_loc_loss_buck) - 0.01
        N = len(safe_pos_buck)
        irreducible_loss = floor
        if N > 6:
            # Fit the middle/late part of the curve. Use an explicit end index:
            # when N // 8 == 0, the slice `[start:-0]` would be empty.
            start, stop = N // 2, N - N // 8
            try:
                irreducible_loss = fit_irreducible(
                    safe_pos_buck[start:stop],
                    safe_loc_loss_buck[start:stop],
                )
            except Exception as e:
                print(f"Warning: Using min loss as irreducible_loss ({e})")
                irreducible_loss = floor

        # Ensure irreducible is finite and below the minimum
        if not np.isfinite(irreducible_loss):
            irreducible_loss = floor
        irreducible_loss = min(irreducible_loss, floor)

    # Subtract irreducible loss to get reducible loss
    safe_loc_loss_buck = safe_loc_loss_buck - irreducible_loss

    return safe_pos_buck, safe_loc_loss_buck, irreducible_loss


def prepare_pg19_data(
    tokenizer,
    max_length: int = 2048,
    num_samples: int = 100,
):
    """Tokenize PG-19 test books and cut them into fixed-length chunks.

    Note that ``num_samples`` limits the number of *books* read, not the number
    of chunks returned: each of the first ``num_samples`` books contributes
    every complete ``max_length``-token chunk it contains; a trailing partial
    chunk is discarded.

    Returns:
        List of token-id lists, each exactly ``max_length`` long.
    """
    print("Loading PG-19 dataset...")
    dataset = load_dataset('emozilla/pg19', split='test', trust_remote_code=True, num_proc=16)

    chunks = []
    for book_idx, example in enumerate(tqdm(dataset, desc="Tokenizing")):
        if book_idx >= num_samples:
            break

        token_ids = tokenizer.encode(example['text'], add_special_tokens=False)
        n_full = len(token_ids) // max_length
        chunks.extend(
            token_ids[k * max_length:(k + 1) * max_length] for k in range(n_full)
        )

    return chunks


def evaluate_config(
    model,
    params,
    config,
    fma_config: Dict,
    samples: List,
    eos_token_id: int,
    max_length: int,
) -> jnp.ndarray:
    """Evaluate a single FMA configuration.

    NOTE: mutates ``config`` in place. In this script ``config`` is the same
    object the model was built with (see ``load_model``), which is how the new
    attention settings reach ``model.apply`` -- presumably the model reads
    them at trace time; confirm in LlamaForCausalLM.

    Args:
        model: Flax model
        params: Model parameters
        config: Base model config (mutated in place)
        fma_config: FMA configuration dict (one entry of FMA_CONFIGS)
        samples: Tokenized samples, each of length ``max_length``
        eos_token_id: EOS token ID
        max_length: Maximum sequence length

    Returns:
        Average loss per position, shape (max_length - 1,): for each position,
        the mean loss over the samples whose mask is True there. Positions
        masked out in every sample come out as 0.
    """
    # Apply FMA config
    config.use_fma_attention = fma_config['use_fma']
    if fma_config['use_fma']:
        config.fma_block_size = fma_config['block_size']
        config.fma_num_clusters = fma_config['num_clusters']
        config.fma_num_retrievals = fma_config['num_retrievals']
        config.fma_bidiagonal = fma_config['bidiagonal']
        # 'dipole' used to be dropped here, so dipole variants silently re-ran
        # their non-dipole counterparts.
        # NOTE(review): attribute name assumed from the fma_* convention above;
        # confirm it matches LlamaConfig.
        dipole = fma_config.get('dipole', False)
        if hasattr(config, 'fma_dipole'):
            config.fma_dipole = dipole
        elif dipole:
            print(f"Warning: config has no 'fma_dipole' attribute; the dipole "
                  f"setting of {fma_config['name']} is ignored")

    # JIT compile for speed. params are a traced argument rather than a
    # closed-over value, so they are not baked into the program as constants.
    jit_compute_loss = jax.jit(
        partial(compute_per_position_loss, model, eos_token_id=eos_token_id),
        out_shardings=P(None, None),
    )

    print(f"Computing per-position losses for {fma_config['name']}...")
    # Accumulate on the host in float64 so summing many samples stays precise.
    loss_sum = np.zeros(max_length - 1, dtype=np.float64)
    token_count = np.zeros(max_length - 1, dtype=np.float64)

    input_sharding = NamedSharding(mesh, P('data', None))

    for sample in tqdm(samples, desc=f"Processing {fma_config['name']}"):
        input_ids = jax.device_put(np.asarray([sample], dtype=np.int32), input_sharding)
        loss_per_pos, mask = jax.device_get(jit_compute_loss(params, input_ids))
        loss_per_pos, mask = loss_per_pos[0], mask[0]  # Remove batch dim

        # np.where instead of multiplying by the mask: a non-finite loss at a
        # masked-out position would otherwise poison the sum (nan * 0 == nan).
        loss_sum += np.where(mask, loss_per_pos, 0.0)
        token_count += mask

    # Average across samples using mask: sum(loss) / sum(mask)
    avg_loss_per_position = loss_sum / np.maximum(token_count, 1e-10)

    return jnp.asarray(avg_loss_per_position)


def main():
    """Evaluate every entry of FMA_CONFIGS on PG-19 and plot loss vs. position.

    Loads the checkpoint and tokenizer, evaluates each configuration on the
    same PG-19 chunks, plots reducible loss (loss minus an irreducible-loss
    estimate) against token position on log-log axes, saves the figure to
    ``--output`` and prints summary statistics.
    """
    parser = argparse.ArgumentParser(description='Compare FMA attention configurations')
    parser.add_argument(
        '--checkpoint_dir',
        type=str,
        default='checkpoints/llama-3.2-1b-flax',
        help='Checkpoint directory',
    )
    parser.add_argument(
        '--max_length',
        type=int,
        default=2**16,
        help='Maximum sequence length',
    )
    parser.add_argument(
        '--num_samples',
        type=int,
        default=100,
        help='Number of samples to evaluate',
    )
    parser.add_argument(
        '--output',
        type=str,
        default='fma_comparison.png',
        help='Output plot filename',
    )
    # NOTE: process_loc_loss plots the cumulative geometric-mean loss by
    # default, in which case this value does not affect the plot.
    parser.add_argument(
        '--num_buckets',
        type=int,
        default=4,
        help='Number of buckets per octave for plotting (default: 4)',
    )
    # Replaces a previously hard-coded override (1.8585) that silently
    # discarded the baseline estimated below.
    parser.add_argument(
        '--irreducible_loss',
        type=float,
        default=None,
        help='Irreducible loss (nats) subtracted from every curve; by default '
             'it is estimated from the Standard Attention curve',
    )

    args = parser.parse_args()

    print(f"Loading model from {args.checkpoint_dir}")
    model, params, base_config = load_model(args.checkpoint_dir, mesh=mesh)

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B')
    eos_token_id = tokenizer.eos_token_id
    print(f"EOS token ID: {eos_token_id}")

    # Prepare data once
    samples = prepare_pg19_data(
        tokenizer,
        max_length=args.max_length,
        num_samples=args.num_samples,
    )
    print(f"Prepared {len(samples)} samples of length {args.max_length}")
    if not samples:
        # With no samples every per-position average would be 0 and the
        # resulting curves meaningless; fail early instead.
        raise SystemExit(
            f"No {args.max_length}-token chunks found in the first "
            f"{args.num_samples} PG-19 books; lower --max_length or raise --num_samples."
        )

    # Evaluate each configuration
    results = {}
    for fma_config in FMA_CONFIGS:
        results[fma_config['name']] = evaluate_config(
            model,
            params,
            base_config,
            fma_config,
            samples,
            eos_token_id,
            args.max_length,
        )

    # Irreducible loss shared by all curves: user-provided, or estimated from
    # the Standard Attention (baseline) curve.
    standard_attn_name = 'Standard Attention'
    if args.irreducible_loss is not None:
        baseline_irreducible = args.irreducible_loss
        print(f"\nUsing provided irreducible loss: {baseline_irreducible:.4f}")
    elif standard_attn_name in results:
        print("\nComputing baseline irreducible loss from Standard Attention...")
        _, _, baseline_irreducible = process_loc_loss(
            np.array(results[standard_attn_name]),
            num_buckets=args.num_buckets
        )
        print(f"Baseline irreducible loss: {baseline_irreducible:.4f}")
    else:
        print("Warning: Standard Attention not found, computing irreducible for each config")
        baseline_irreducible = None

    # Create comparison plot
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    colors = ['black', 'blue', 'red', 'green', 'purple', 'orange']

    for idx, (config_name, avg_loss) in enumerate(results.items()):
        # None -> process_loc_loss estimates an irreducible loss per config
        x, y, _ = process_loc_loss(
            np.array(avg_loss),
            num_buckets=args.num_buckets,
            fixed_irreducible=baseline_irreducible
        )
        ax.plot(
            x, y,
            linestyle='-',
            linewidth=2,
            color=colors[idx % len(colors)],
            label=config_name,
            alpha=0.8,
        )

    title = 'FMA Attention Configuration Comparison - Loss by Context Depth'
    if baseline_irreducible is not None:
        title += f'\nBaseline Irreducible Loss: {baseline_irreducible:.4f} nats'
    ax.set_title(title)
    ax.set_xlabel('Token Position')
    ax.set_ylabel('Reducible Loss (nats)')
    ax.set_yscale('log')
    ax.set_xscale('log')
    ax.grid(True, alpha=0.3)
    ax.legend(loc='best', fontsize=10)

    plt.tight_layout()
    plt.savefig(args.output, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"\nComparison plot saved to {args.output}")

    # Print summary statistics
    print(f"\nSummary Statistics:")
    print(f"  Context length: {args.max_length}")
    print(f"  Num samples: {len(samples)}")
    if baseline_irreducible is not None:
        print(f"  Baseline irreducible loss: {baseline_irreducible:.4f} nats")
    print()
    for config_name, avg_loss in results.items():
        mean_loss = float(jnp.mean(avg_loss))
        print(f"{config_name}:")
        print(f"  Average loss: {mean_loss:.4f} nats")
        print(f"  Average perplexity: {float(jnp.exp(jnp.mean(avg_loss))):.4f}")
        print(f"  First token loss: {float(avg_loss[0]):.4f} nats")
        print(f"  Last token loss: {float(avg_loss[-1]):.4f} nats")
        print()


if __name__ == '__main__':
    # main() returns None, so this exits with status 0 on success.
    raise SystemExit(main())
