import argparse
import copy
import json
import os

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from flax.core import freeze, unfreeze
from flax.jax_utils import replicate, unreplicate
from flax.serialization import from_bytes
from flax.training.train_state import TrainState
from jax import lax, random
from tqdm import tqdm
from collections import Counter


from src.matching_utils import matching_attn_all_heads_permu
from src.utils import lerp

# --- Evaluation logic for BERT (adapted for pmap) ---
def eval_step(state, batch):
    """Compute masked cross-entropy loss and accuracy for one sharded batch.

    Written to run under a mapped axis named ``'batch'`` (e.g. ``jax.pmap`` or
    ``jax.vmap`` with ``axis_name='batch'``): the per-shard sums are combined
    with ``lax.psum`` so every shard returns the same batch-wide metrics.

    Args:
        state: Object exposing ``apply_fn`` and ``params``. ``apply_fn`` is
            called as ``apply_fn({'params': params}, inputs, pad_mask,
            deterministic=True)`` and must return logits whose last axis
            indexes the classes.
        batch: Mapping with keys ``'inputs'``, ``'pad_mask'``, ``'labels'``
            (integer class ids) and ``'valid_mask'`` (per-example weights;
            0 marks a padding example that must not affect the metrics).

    Returns:
        Dict with ``'loss'`` (mean cross-entropy over valid examples) and
        ``'accuracy'`` (fraction of valid examples predicted correctly).
        If no example in the whole batch is valid the division is 0/0.
    """
    logits = state.apply_fn({'params': state.params}, batch['inputs'], batch['pad_mask'],
                            deterministic=True)
    labels = batch['labels']
    valid_mask = batch['valid_mask']
    is_valid = valid_mask != 0

    one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
    per_example_loss = -jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1)
    per_example_acc = (jnp.argmax(logits, -1) == labels).astype(jnp.float32)

    # Multiplying by a zero mask does not neutralise non-finite values
    # (NaN * 0 == NaN), so explicitly zero the masked-out rows first. This
    # keeps a non-finite loss on a padding example from poisoning the metric.
    per_example_loss = jnp.where(is_valid, per_example_loss, 0.0)
    per_example_acc = jnp.where(is_valid, per_example_acc, 0.0)

    # One shared denominator: the number (weight) of valid examples across
    # all shards of the mapped axis.
    denominator = lax.psum(jnp.sum(valid_mask), axis_name='batch')
    loss_numerator = lax.psum(jnp.sum(per_example_loss * valid_mask), axis_name='batch')
    acc_numerator = lax.psum(jnp.sum(per_example_acc * valid_mask), axis_name='batch')

    return {
        'loss': loss_numerator / denominator,
        'accuracy': acc_numerator / denominator,
    }


def main():
    """Run attention-head permutation matching between two AGNews BERT checkpoints.

    Steps, in order:
      1. Parse the command line and import the RoPE or non-RoPE training
         module depending on ``--rope-use``.
      2. Load AGNews, rebuild the vocabulary from the training split and
         convert both splits to token ids.
      3. Load both checkpoints into the parameter structure of ``BertModel``.
      4. Collect activations from both models (the same rng key is passed
         for both calls).
      5. Evaluate the naive linear interpolation between the two models on
         the test split, then the interpolation for every aligned parameter
         set returned by ``matching_attn_all_heads_permu``.
      6. Write all results to ``matching_results.json`` and save the plots
         under ``--plot-path``.

    Invalid argument combinations are reported through ``parser.error``
    (usage message, exit status 2).
    """
    parser = argparse.ArgumentParser(description="Attention head permutation matching for AGNews BERT models.")
    parser.add_argument("--rope-use", action='store_true', help="Use RoPE-enabled BERT model and matching if this flag is present")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first fine-tuned BERT model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second fine-tuned BERT model checkpoint")
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
    parser.add_argument("--num-layers", type=int, required=True, help="Number of transformer layers")
    parser.add_argument("--num-heads", type=int, required=True, help="Number of attention heads")
    parser.add_argument("--embedding-dim", type=int, default=48, help="Dimension of token embeddings")
    parser.add_argument("--hidden-dim", type=int, default=192, help="Dimension of the feed-forward network")
    parser.add_argument("--finetune-layer-which", type=str, required=True, help="Which attention layer to match (single layer index)")
    parser.add_argument("--plot-path", type=str, default="./plots/agnews_head_matching", help="Path to save plots and results")
    parser.add_argument("--data-path", type=str, default="/data/agnews", help="Path to AGNews data")
    parser.add_argument("--batch-size", type=int, default=2048, help="Batch size for evaluation")
    args = parser.parse_args()

    # Conditional imports based on rope-use argument
    if not args.rope_use:
        from .agnews_bert_train import BertModel, pad_sequences, load_agnews_data, get_mha_inputs
    else:
        # Assumes a agnews_bert_train_rope.py exists with a RoPE-enabled model
        from .agnews_bert_train_rope import BertModel, pad_sequences, load_agnews_data, get_mha_inputs

    # --- Configuration ---
    class Config:
        pass
    config = Config()
    config.rope_use = args.rope_use
    config.seed = args.seed
    config.num_layers = args.num_layers
    config.finetune_layer_which = [int(idx) for idx in args.finetune_layer_which.split(",")]
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if len(config.finetune_layer_which) != 1:
        parser.error("finetune_layer_which must have exactly one layer for this script")
    config.num_heads = args.num_heads
    config.plot_path = args.plot_path
    os.makedirs(config.plot_path, exist_ok=True)

    # --- JAX Device Setup ---
    num_devices = jax.local_device_count()
    print(f"Using {num_devices} devices.")
    if args.batch_size <= 0:
        parser.error("--batch-size must be a positive integer")
    if args.batch_size % num_devices != 0:
        parser.error(f"Batch size ({args.batch_size}) must be divisible by the number of devices ({num_devices}).")
    per_device_batch_size = args.batch_size // num_devices

    # --- Model and Data Loading ---
    print("Loading AGNews data...")
    (x_train, y_train), (x_test, y_test), num_classes, PAD = load_agnews_data(args.data_path)

    # Build vocabulary from training data (must be identical to training)
    print("Building vocabulary...")
    total_vocab_size = 15000  # includes the two reserved ids below
    word_counts = Counter(word for seq in x_train for word in seq)
    # Reserve 0 for PAD, 1 for UNK
    most_common = word_counts.most_common(total_vocab_size - 2)
    vocab = {"<PAD>": 0, "<UNK>": 1}
    for idx, (word, _) in enumerate(most_common, 2):
        vocab[word] = idx

    def to_ids(seq):
        """Map a sequence of words to ids, using 1 (UNK) for unknown words."""
        return [vocab.get(word, 1) for word in seq]

    train_ds = {
        "token_ids": [to_ids(seq) for seq in x_train],
        "labels": np.array(y_train, dtype=np.int32),
    }
    test_ds = {
        "token_ids": [to_ids(seq) for seq in x_test],
        "labels": np.array(y_test, dtype=np.int32),
    }

    model = BertModel(
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        num_classes=num_classes,
        vocab_size=total_vocab_size,
        max_seq_len=100,
    )

    rng = random.PRNGKey(config.seed)

    def load_model_params(filepath, model, rng):
        """Deserialize a checkpoint into the parameter structure of ``model``.

        The model is initialised once on a dummy all-zero input only to
        obtain a parameter tree that ``from_bytes`` can use as a template.
        """
        dummy_input = jnp.zeros((1, model.max_seq_len), dtype=jnp.int32)
        init_rng, dropout_rng = random.split(rng)
        params_struct = model.init({'params': init_rng, 'dropout': dropout_rng}, dummy_input)['params']
        with open(filepath, "rb") as fh:
            return from_bytes(params_struct, fh.read())

    rng, model_rng_a = random.split(rng)
    model_a_params = load_model_params(args.model_a, model, model_rng_a)
    rng, model_rng_b = random.split(rng)
    model_b_params = load_model_params(args.model_b, model, model_rng_b)

    lambdas = jnp.linspace(0, 1, num=9)

    # --- Get Activations ---
    activation_batch_size = 1024
    rng, activation_rng = random.split(rng)
    print(f"Computing activations from models on a batch of {activation_batch_size} training examples...")
    # AGNews get_mha_inputs needs model, params, dataset, rng, batch_size, vocab, pad_value
    activations_a = get_mha_inputs(model, model_a_params, train_ds, activation_rng, activation_batch_size, vocab, PAD)
    activations_b = get_mha_inputs(model, model_b_params, train_ds, activation_rng, activation_batch_size, vocab, PAD)
    print(f"Got activations for {len(activations_a)} layers from each model.")

    # --- Interpolation & Evaluation Function ---
    # Built once and reused by every evaluation call below.
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    def dataset_loss_and_accuracy(params, model_apply_fn, dataset, batch_size, PAD, num_devices, per_device_batch_size):
        """Return the example-weighted mean (loss, accuracy) over ``dataset``.

        The dataset is processed in chunks of ``batch_size``; a short final
        chunk is padded with ``[PAD]`` sequences and zero labels, and those
        padding rows are flagged with ``valid_mask == 0`` for ``eval_step``.
        Per-batch metrics are averaged weighted by the number of real
        examples in each batch, so a short final batch is not over-weighted.

        Raises:
            ValueError: If ``dataset`` contains no examples.
        """
        num_examples = len(dataset['labels'])
        if num_examples == 0:
            raise ValueError("Cannot evaluate on an empty dataset.")

        state = TrainState.create(apply_fn=model_apply_fn, params=params, tx=optax.sgd(1e-3))  # optimizer is a dummy
        state = replicate(state)

        def shard(array):
            """Split the leading axis into (num_devices, per_device_batch_size)."""
            return jnp.reshape(array, (num_devices, per_device_batch_size) + array.shape[1:])

        batch_metrics = []
        batch_sizes = []
        for i in tqdm(range(0, num_examples, batch_size), desc="Evaluating", leave=False):
            end = min(i + batch_size, num_examples)
            current_batch_size = end - i
            pad_size = batch_size - current_batch_size

            # Copy the slice so padding never touches the dataset itself.
            token_ids = list(dataset["token_ids"][i:end])
            labels = dataset["labels"][i:end]

            if pad_size > 0:
                token_ids.extend([[PAD]] * pad_size)
                labels = np.concatenate([labels, np.zeros(pad_size, dtype=labels.dtype)])

            x_batch = jnp.array(pad_sequences(token_ids, pad_value=PAD))
            pad_mask = (x_batch != PAD)[:, None, None, :]
            valid_mask = jnp.array([1.0] * current_batch_size + [0.0] * pad_size)

            batch = {
                'inputs': shard(x_batch),
                'pad_mask': shard(pad_mask),
                'labels': shard(jnp.array(labels)),
                'valid_mask': shard(valid_mask),
            }

            batch_metrics.append(unreplicate(p_eval_step(state, batch)))
            batch_sizes.append(current_batch_size)

        # Each batch metric is a mean over ``current_batch_size`` real
        # examples, so weighting by that count recovers the dataset mean.
        weights = np.asarray(batch_sizes, dtype=np.float64)
        mean_metrics = {
            k: float(np.average([float(m[k]) for m in batch_metrics], weights=weights))
            for k in batch_metrics[0]
        }
        return mean_metrics['loss'], mean_metrics['accuracy']

    def compute_interpolation(model_a, model_b_target, lambdas, desc="Interpolation"):
        """Evaluate ``lerp(lam, model_a, model_b_target)`` on the test split for each lambda.

        Train metrics are not computed; they are recorded as 0.0 placeholders
        so the result keeps all four metric keys.
        """
        metrics = {"Train Loss": [], "Test Loss": [], "Train Acc": [], "Test Acc": []}
        for lam in tqdm(lambdas, desc=desc):
            p_interp = freeze(lerp(lam, unfreeze(model_a), unfreeze(model_b_target)))
            # Skip train evaluation for speed
            train_loss, train_acc = 0.0, 0.0
            test_loss, test_acc = dataset_loss_and_accuracy(
                p_interp, model.apply, test_ds, args.batch_size, PAD, num_devices, per_device_batch_size
            )
            metrics["Train Loss"].append(float(f"{train_loss:.4f}"))
            metrics["Test Loss"].append(float(f"{test_loss:.4f}"))
            metrics["Train Acc"].append(float(f"{train_acc:.4f}"))
            metrics["Test Acc"].append(float(f"{test_acc:.4f}"))
        return metrics

    # --- Run Matching and Interpolations ---
    # 1. Naive Interpolation
    naive_results = compute_interpolation(model_a_params, model_b_params, lambdas, desc="Naive Interpolation")

    # 2. Get Aligned Parameters for All Permutations
    print("\nStarting attention head matching for all permutations...")
    params_dict, heads_objective_values, heads_permutation_sol = matching_attn_all_heads_permu(
        rng, model_a_params, model_b_params, config.finetune_layer_which, config.num_heads, config.plot_path,
        rope_use=config.rope_use, activations_a=activations_a, activations_b=activations_b,
    )

    # 3. Compute Interpolation for All Permutations
    layer_key = f'TransformerEncoderLayer_{config.finetune_layer_which[0]}'
    attention_key = 'MultiHeadDotProductAttention_0'
    all_perm_results = {}
    for setting in params_dict:
        all_perm_results[setting] = {}
        for perm_str, aligned_result in params_dict[setting].items():
            aligned_attn_block = aligned_result.get('aligned_params', aligned_result)
            # Start from a fresh copy of model B and swap in only the aligned
            # attention block of the selected layer.
            params_b_full_aligned = copy.deepcopy(unfreeze(model_b_params))
            params_b_full_aligned[layer_key][attention_key] = aligned_attn_block

            results = compute_interpolation(model_a_params, freeze(params_b_full_aligned), lambdas, desc=f"{setting} {perm_str}")
            all_perm_results[setting][perm_str] = results

    # --- Consolidate and Save All Results ---
    # Element [1] of each solution is rendered as a list-of-ints string so it
    # can be looked up among the permutation keys of ``all_perm_results``.
    highlighted_perm_str = {
        alpha: str([int(col) for col in solution[1]])
        for alpha, solution in heads_permutation_sol.items()
    }

    final_results = {
        "naive_interpolation": naive_results,
        "all_permutation_interpolations": all_perm_results,
        "selected_permutations_by_method": highlighted_perm_str,
        "objective_values": heads_objective_values,
    }

    output_file = os.path.join(config.plot_path, "matching_results.json")
    with open(output_file, "w") as f:
        json.dump(final_results, f, indent=2)
    print(f"\nAll results saved to {output_file}")

    # --- Plotting ---
    print("Generating plots...")
    lambda_values = np.array(lambdas)
    metric_names = ["Train Loss", "Test Loss", "Train Acc", "Test Acc"]
    positions = [(0, 0), (0, 1), (1, 0), (1, 1)]

    plt.style.use('tableau-colorblind10')
    # One colour per highlighted permutation; identical for every subplot.
    colors = plt.cm.viridis(np.linspace(0, 1, len(highlighted_perm_str)))

    for setting in ['init_ortho_opt', 'init_ortho_no_opt']:
        if setting not in all_perm_results:
            continue

        for include_naive in [True, False]:
            fig, axs = plt.subplots(2, 2, figsize=(14, 11))
            fig.suptitle(f'AGNews BERT Interpolation Plots: {setting}' + (' with Naive' if include_naive else ''), fontsize=16)

            for metric, pos in zip(metric_names, positions):
                ax = axs[pos]

                # Handle skipped train evaluation in plots
                if "Train" in metric:
                    ax.text(0.5, 0.5, 'Train evaluation skipped for speed', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, alpha=0.5)
                    ax.set_title(metric)
                    ax.grid(True, linestyle='--', alpha=0.6)
                    continue

                # Plot all permutations in light gray
                for perm_results in all_perm_results[setting].values():
                    ax.plot(lambda_values, perm_results[metric], color='gray', alpha=0.4, linewidth=1.5)

                ax.plot([], [], color='gray', alpha=0.5, label='Other Permutations')  # For legend

                if include_naive:
                    ax.plot(lambda_values, naive_results[metric], label='Naive', color='blue', linestyle='--')

                # Plot highlighted permutations
                for idx, (alpha, perm_str) in enumerate(highlighted_perm_str.items()):
                    if perm_str in all_perm_results[setting]:
                        ax.plot(lambda_values, all_perm_results[setting][perm_str][metric],
                                label=f"$\\alpha$={alpha}: {perm_str}", color=colors[idx])

                ax.set_title(metric)
                ax.set_xlabel(r"$\lambda$")
                ax.set_ylabel("Value")
                ax.grid(True, linestyle='--', alpha=0.6)

            # The legend entries are taken from the bottom-right (Test Acc) axes.
            handles, labels = axs[1, 1].get_legend_handles_labels()
            fig.legend(handles, labels, loc='upper left')
            plt.tight_layout(rect=[0, 0, 1, 0.96])

            plot_name = f"interpolation_agnews_{setting}" + ("" if include_naive else "_no_naive") + ".png"
            save_path = os.path.join(config.plot_path, plot_name)
            fig.savefig(save_path)
            plt.close(fig)

    print(f"Plots saved in {config.plot_path}")

# Run the matching pipeline only when this file is executed as the main
# module, not when it is imported.
if __name__ == "__main__":
    main()