import argparse
import os
import jax
import jax.numpy as jnp
import numpy as np
import flax
from flax.core import freeze, unfreeze
from flax.training.train_state import TrainState
from flax.serialization import from_bytes
from jax import random, lax
from tqdm import tqdm
import json
import matplotlib.pyplot as plt
import optax
from flax.jax_utils import replicate, unreplicate

# Attention-weight matching for both the standard and the RoPE-enabled BERT variants;
# the variant used at runtime is selected by the --rope-use flag.
from ..matching_utils import matching_attn, matching_attn_rope
from ..utils import lerp


# Evaluation step meant to run under jax.pmap(..., axis_name='batch'); metrics are reduced across devices.
def eval_step(state, batch):
    """Compute masked loss and accuracy for one device shard inside ``jax.pmap``.

    Must be mapped with ``axis_name='batch'``: per-device sums are combined
    with ``lax.psum``, so every device returns the same global metrics.

    Args:
        state: object exposing ``apply_fn`` and ``params`` (e.g. a replicated
            ``TrainState``).
        batch: dict holding this device's shard of
            - 'inputs': token ids passed to the model,
            - 'pad_mask': padding mask passed to the model,
            - 'labels': integer class labels,
            - 'valid_mask': per-example weights; 0 marks filler rows used to
              pad the last batch.

    Returns:
        dict with 'loss' (weighted mean cross-entropy) and 'accuracy'
        (weighted mean top-1 accuracy) over the valid examples of all devices.
    """
    logits = state.apply_fn({'params': state.params}, batch['inputs'], batch['pad_mask'],
                            deterministic=True)
    labels = batch['labels']
    weights = batch['valid_mask']
    is_valid = weights > 0

    one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
    per_example_loss = -jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1)
    per_example_acc = (jnp.argmax(logits, -1) == labels).astype(jnp.float32)

    # Select instead of multiplying: if the model yields NaN/inf for a filler
    # row (model-dependent, e.g. an all-padding sequence), NaN * 0 is still
    # NaN and would poison the cross-device sum for the whole batch.
    masked_loss = jnp.where(is_valid, per_example_loss * weights, 0.0)
    masked_acc = jnp.where(is_valid, per_example_acc * weights, 0.0)

    # Both metrics share the same denominator, so reduce it only once.
    total_weight = lax.psum(jnp.sum(weights), axis_name='batch')
    loss = lax.psum(jnp.sum(masked_loss), axis_name='batch') / total_weight
    accuracy = lax.psum(jnp.sum(masked_acc), axis_name='batch') / total_weight

    return {'loss': loss, 'accuracy': accuracy}

def main():
    """Align two fine-tuned IMDB BERT checkpoints and evaluate linear interpolations.

    Steps:
      1. Parse CLI arguments and import the standard or RoPE model/data helpers.
      2. Load the IMDB data and both checkpoints (model A and model B).
      3. Collect attention inputs from both models for data-dependent matching.
      4. Evaluate the naive interpolation ``lerp(lam, A, B)`` on the test set
         for 11 evenly spaced ``lam`` values in [0, 1].
      5. Align model B to model A with ``matching_attn`` (or
         ``matching_attn_rope``) and evaluate the interpolation for every
         aligned model returned.
      6. Write ``interpolation_results.json`` and two plots to ``--plot-path``.

    Train-set evaluation is skipped for speed; the "Train Loss" and
    "Train Acc" entries are placeholder zeros.
    """
    parser = argparse.ArgumentParser(description="Weight matching and interpolation for IMDB BERT models.")
    parser.add_argument("--rope-use", action='store_true', help="Use RoPE-enabled BERT model and matching if this flag is present")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first fine-tuned BERT model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second fine-tuned BERT model checkpoint")
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
    parser.add_argument("--num-layers", type=int, required=True, help="Number of transformer layers in the models")
    parser.add_argument("--num-heads", type=int, required=True, help="Number of attention heads in the models")
    parser.add_argument("--finetune-layer-which", type=str, required=True, help="Comma-separated indices ('all' for all) of attention layers to match")
    parser.add_argument("--batch-size", type=int, default=512, help="Batch size for evaluation")
    parser.add_argument("--plot-path", type=str, default="./plots/imdbreview_attn_matching", help="Path to save plots and results")
    parser.add_argument("--data-path", type=str, required=True, help="Path to IMDB CSV data file")
    parser.add_argument("--max-seq-len", type=int, default=256, help="Maximum sequence length (must match the one used for training)")
    parser.add_argument("--activation-batch-size", type=int, default=512, help="Number of training examples used to collect activations for matching")
    args = parser.parse_args()

    # The standard and RoPE variants expose the same helper names, so the
    # rest of the script is agnostic to which one was imported.
    if not args.rope_use:
        from .imdbreview_bert_train import BertModel, pad_sequences, load_imdb_review_data, get_mha_inputs
    else:
        from .imdbreview_bert_train_rope import BertModel, pad_sequences, load_imdb_review_data, get_mha_inputs

    num_devices = jax.local_device_count()
    print(f"Using {num_devices} devices")
    if args.batch_size <= 0 or args.batch_size % num_devices != 0:
        # parser.error prints usage and exits; unlike `assert`, it is not
        # stripped when Python runs with -O.
        parser.error(f"Batch size must be a positive multiple of the number of devices ({num_devices})")
    per_device_batch_size = args.batch_size // num_devices

    # Wrap the evaluation step once and reuse it for every interpolation point.
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    def dataset_loss_and_accuracy(params, model_apply_fn, dataset, batch_size, PAD, num_devices, per_device_batch_size):
        """
        Computes the mean loss and accuracy of a model on a given dataset.

        Examples are processed in chunks of ``batch_size`` split evenly across
        ``num_devices``. The final partial chunk is filled with single-token
        ``[PAD]`` rows that are excluded from the metrics via ``valid_mask``.

        Args:
            params: model parameters to evaluate.
            model_apply_fn: the model's ``apply`` function.
            dataset: dict with "token_ids" (sequence of token-id sequences)
                and "labels" (NumPy integer array).
            batch_size: global batch size (num_devices * per_device_batch_size).
            PAD: padding token id.
            num_devices: number of local devices used by ``pmap``.
            per_device_batch_size: examples per device per step.

        Returns:
            (loss, accuracy) averaged over every example in ``dataset``.

        Raises:
            ValueError: if ``dataset`` contains no examples.
        """
        num_examples = len(dataset['labels'])
        if num_examples == 0:
            raise ValueError("Cannot evaluate on an empty dataset")

        # TrainState requires an optimizer, but it is never applied here.
        state = replicate(TrainState.create(apply_fn=model_apply_fn, params=params, tx=optax.sgd(1e-3)))

        def shard(array):
            # Host-side reshape to (devices, per_device, ...); pmap then sends
            # one shard to each device.
            return np.reshape(array, (num_devices, per_device_batch_size) + array.shape[1:])

        batch_losses, batch_accs, batch_counts = [], [], []
        for i in tqdm(range(0, num_examples, batch_size), desc="Evaluating", leave=False):
            end = min(i + batch_size, num_examples)
            current_batch_size = end - i
            pad_size = batch_size - current_batch_size

            # Copy into a fresh list so padding never mutates the dataset and
            # works whether the slice is a list or an array of rows.
            token_ids = list(dataset["token_ids"][i:end])
            labels = np.asarray(dataset["labels"][i:end])
            if pad_size > 0:
                token_ids.extend([PAD] for _ in range(pad_size))
                labels = np.concatenate([labels, np.zeros(pad_size, dtype=labels.dtype)])

            x_batch = np.asarray(pad_sequences(token_ids, pad_value=PAD), dtype=np.int32)
            pad_mask = (x_batch != PAD)[:, None, None, :]
            valid_mask = np.concatenate([
                np.ones(current_batch_size, dtype=np.float32),
                np.zeros(pad_size, dtype=np.float32),
            ])

            batch = {
                'inputs': shard(x_batch),
                'pad_mask': shard(pad_mask),
                'labels': shard(labels),
                'valid_mask': shard(valid_mask),
            }

            metrics = unreplicate(p_eval_step(state, batch))
            batch_losses.append(metrics['loss'])
            batch_accs.append(metrics['accuracy'])
            batch_counts.append(current_batch_size)

        # Each batch metric is already averaged over that batch's valid
        # examples; weight by the valid count so a short final batch is not
        # over-weighted relative to full batches.
        losses, accs = jax.device_get((batch_losses, batch_accs))
        loss = float(np.average(np.asarray(losses), weights=batch_counts))
        accuracy = float(np.average(np.asarray(accs), weights=batch_counts))
        return loss, accuracy

    def compute_interpolation(model_a_params, model_b_params, model_apply_fn, lambdas, train_ds, test_ds, batch_size, PAD, num_devices, per_device_batch_size, desc="Interpolation"):
        """
        Computes metrics for linear interpolation between two models.

        For each ``lam`` in ``lambdas`` the parameters ``lerp(lam, A, B)`` are
        evaluated on ``test_ds``. ``train_ds`` is currently unused because
        train-set evaluation is skipped for speed.

        Returns:
            dict mapping "Train Loss", "Test Loss", "Train Acc" and "Test Acc"
            to one value per lambda, rounded to 4 decimals (train values are 0).
        """
        results = {"Train Loss": [], "Test Loss": [], "Train Acc": [], "Test Acc": []}

        for lam in tqdm(lambdas, desc=desc):
            p_interp = freeze(lerp(lam, unfreeze(model_a_params), unfreeze(model_b_params)))

            train_loss, train_acc = 0, 0  # Skipping train set evaluation for speed
            test_loss, test_acc = dataset_loss_and_accuracy(p_interp, model_apply_fn, test_ds, batch_size, PAD, num_devices, per_device_batch_size)

            for key, value in (("Train Loss", train_loss), ("Test Loss", test_loss),
                               ("Train Acc", train_acc), ("Test Acc", test_acc)):
                results[key].append(float(f"{value:.4f}"))

        return results

    def load_model_params(filepath, model, rng):
        """Initializes a parameter template for ``model`` and fills it from a serialized checkpoint."""
        dummy_input = jnp.zeros((1, model.max_seq_len), dtype=jnp.int32)
        init_rng, dropout_rng = random.split(rng)
        params_struct = model.init({'params': init_rng, 'dropout': dropout_rng}, dummy_input)['params']

        with open(filepath, "rb") as f:
            loaded_params = from_bytes(params_struct, f.read())
        return loaded_params

    def plot_interpolation(all_results, lambda_values, save_path, exclude=()):
        """Draws a 2x2 grid of metric-vs-lambda curves (one line per method, skipping ``exclude``) and saves it."""
        metrics = ["Train Loss", "Test Loss", "Train Acc", "Test Acc"]
        fig, axs = plt.subplots(2, 2, figsize=(12, 10))

        # axs.flat walks the grid row-major: (0,0), (0,1), (1,0), (1,1).
        for ax, metric in zip(axs.flat, metrics):
            for method, results in all_results.items():
                if method in exclude:
                    continue
                ax.plot(lambda_values, results[metric], label=method)
            ax.set_xticks([0, 0.5, 1])
            ax.set_xticklabels(["Model A", r"$\lambda$", "Model B"])
            ax.set_ylabel(metric)
            ax.grid(True, linestyle='--', alpha=0.6)
        axs[0, 0].legend(loc='best')

        fig.tight_layout()
        fig.savefig(save_path, dpi=300)
        plt.close(fig)

    os.makedirs(args.plot_path, exist_ok=True)
    rng = random.PRNGKey(args.seed)

    # --- Data Loading and Preprocessing ---
    print("Loading and preprocessing IMDB data...")
    (x_train, y_train), (x_test, y_test), num_classes, PAD = load_imdb_review_data(args.data_path, max_seq_len=args.max_seq_len)

    # Data is already tokenized and converted to IDs by the loader
    train_ds = {
        "token_ids": x_train,
        "labels": np.array(y_train, dtype=np.int32)
    }
    test_ds = {
        "token_ids": x_test,
        "labels": np.array(y_test, dtype=np.int32)
    }

    # --- Model Definition and Loading ---
    model = BertModel(
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        num_classes=num_classes,
        max_seq_len=args.max_seq_len  # must match the length used in training
    )

    print(f"Loading model A from: {args.model_a}")
    rng, model_rng = random.split(rng)
    model_a_params = load_model_params(args.model_a, model, model_rng)

    print(f"Loading model B from: {args.model_b}")
    rng, model_rng = random.split(rng)
    model_b_params = load_model_params(args.model_b, model, model_rng)

    # Get activations for data-dependent matching. The same key is passed for
    # both models, presumably so both see the same examples — confirm in get_mha_inputs.
    activation_batch_size = args.activation_batch_size
    rng, activation_rng = random.split(rng)
    print(f"Computing activations from model A on a batch of {activation_batch_size} training examples...")
    activations_a = get_mha_inputs(model, model_a_params, train_ds, activation_rng, activation_batch_size, vocab=None, pad_value=PAD)
    activations_b = get_mha_inputs(model, model_b_params, train_ds, activation_rng, activation_batch_size, vocab=None, pad_value=PAD)
    print(f"Got {len(activations_a)} sets of activations, for each of the {args.num_layers} layers.")

    # --- Interpolation and Evaluation ---
    lambdas = jnp.linspace(0, 1, num=11)
    all_results = {}

    # 1. Naive Interpolation
    naive_results = compute_interpolation(
        model_a_params, model_b_params, model.apply, lambdas, train_ds, test_ds, args.batch_size, PAD, num_devices, per_device_batch_size, desc="Naive Interpolation"
    )
    all_results["Naive"] = naive_results
    print("\n--- Naive Interpolation Results ---")
    print(json.dumps({"Naive": naive_results}, indent=2))

    # 2. Weight Matching on Attention Layers
    if args.finetune_layer_which == "all":
        finetune_layer_indices = list(range(args.num_layers))
    else:
        finetune_layer_indices = [int(idx) for idx in args.finetune_layer_which.split(",")]
    print(f"\nPerforming attention weight matching on layers: {finetune_layer_indices}")

    rng, matching_rng = random.split(rng)
    # The two matching functions have different signatures; only the standard one takes an RNG and plot path.
    if not args.rope_use:
        aligned_models = matching_attn(
            matching_rng, model_a_params, model_b_params, activations_a, activations_b, finetune_layer_indices, args.num_heads, args.plot_path
        )
    else:
        aligned_models = matching_attn_rope(
            model_a_params, model_b_params, activations_a, activations_b, finetune_layer_indices, args.num_heads
        )

    for method, model_b_aligned in aligned_models.items():
        method_results = compute_interpolation(
            model_a_params, model_b_aligned, model.apply, lambdas, train_ds, test_ds, args.batch_size, PAD, num_devices, per_device_batch_size, desc=f"{method} Interpolation"
        )
        all_results[method] = method_results
        print(f"\n--- {method} Interpolation Results ---")
        print(json.dumps({method: method_results}, indent=2))

    # Save all results to a JSON file
    results_path = os.path.join(args.plot_path, "interpolation_results.json")
    with open(results_path, "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nAll results saved to {results_path}")

    # --- Plotting ---
    plt.rcParams.update({
        "font.family": "serif", 'legend.frameon': False, 'lines.linewidth': 2, 'font.size': 13,
        'axes.labelsize': 16, 'xtick.labelsize': 11, 'ytick.labelsize': 11, 'legend.fontsize': 11,
    })
    plt.style.use('tableau-colorblind10')

    lambda_values = np.asarray(lambdas)

    save_path = os.path.join(args.plot_path, "interpolation_plots.png")
    plot_interpolation(all_results, lambda_values, save_path)
    print(f"Combined plot saved to {save_path}")

    save_path = os.path.join(args.plot_path, "interpolation_plots_no_naive.png")
    plot_interpolation(all_results, lambda_values, save_path, exclude=("Naive",))
    print(f"Plot without naive baseline saved to {save_path}")


if __name__ == "__main__":
    main()