from flax.core import freeze, unfreeze
import argparse
import jax.numpy as jnp
from jax import random, lax
import jax
from tqdm import tqdm
from flax.serialization import from_bytes
import json
import numpy as np
import matplotlib.pyplot as plt
import os
import copy
from flax.training.train_state import TrainState
import optax

from src.utils import lerp
from src.matching_utils import matching_attn_all_heads_permu

# --- Evaluation logic for BERT (non-pmapped, single device) ---
@jax.jit
def eval_step(state, batch):
    """Compute the mean loss and accuracy of ``state`` on one batch (single device).

    Args:
        state: Object carrying ``apply_fn`` and ``params`` (a flax ``TrainState``
            in this script). ``apply_fn`` is called with the variables dict
            ``{'params': state.params}`` and ``deterministic=True``.
        batch: Dict with keys ``'inputs'``, ``'pad_mask'`` and integer ``'labels'``.

    Returns:
        Dict with two scalars:
        - ``'loss'``: mean softmax cross-entropy over the batch.
        - ``'accuracy'``: fraction of examples whose argmax logit equals the label.
    """
    labels = batch['labels']
    logits = state.apply_fn(
        {'params': state.params},
        batch['inputs'],
        pad_mask=batch['pad_mask'],
        deterministic=True,
    )

    # Cross-entropy written out explicitly against one-hot targets.
    log_probs = jax.nn.log_softmax(logits)
    targets = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
    nll = -(targets * log_probs).sum(axis=-1)

    correct = (logits.argmax(axis=-1) == labels).astype(jnp.float32)

    return {'loss': nll.mean(), 'accuracy': correct.mean()}

def _parse_args(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as ``argparse`` does.

    Returns:
        ``(args, finetune_layers)``. ``args`` is the parsed namespace.
        ``finetune_layers`` is ``--finetune-layer-which`` converted to a list
        of ints, and always holds exactly one layer index.

    Exits with status 2 (via ``parser.error``) when ``--finetune-layer-which``
    is not a comma-separated list of integers or does not name exactly one
    layer.
    """
    parser = argparse.ArgumentParser(description="Attention head permutation matching for IMDB BERT models.")
    parser.add_argument("--rope-use", action='store_true', help="Use RoPE-enabled BERT model and matching if this flag is present")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first fine-tuned BERT model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second fine-tuned BERT model checkpoint")
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
    parser.add_argument("--num-layers", type=int, required=True, help="Number of transformer layers")
    parser.add_argument("--num-heads", type=int, required=True, help="Number of attention heads")
    parser.add_argument("--finetune-layer-which", type=str, required=True, help="Which attention layer to match (single layer index)")
    parser.add_argument("--plot-path", type=str, default="./plots", help="Path to save plots and results")
    parser.add_argument("--data-path", type=str, required=True, help="Path to IMDB CSV data file")
    parser.add_argument("--batch-size", type=int, default=2048, help="Batch size for evaluation")
    args = parser.parse_args(argv)

    try:
        finetune_layers = [int(idx) for idx in args.finetune_layer_which.split(",")]
    except ValueError:
        parser.error(
            "--finetune-layer-which must be a comma-separated list of integers, "
            f"got {args.finetune_layer_which!r}"
        )
    # Validated with parser.error rather than `assert`, which is stripped under `python -O`.
    if len(finetune_layers) != 1:
        parser.error("finetune_layer_which must have exactly one layer for this script")
    return args, finetune_layers


def _json_default(obj):
    """Fallback for ``json.dump`` that serializes array-likes (NumPy/JAX scalars and arrays).

    Anything exposing ``__array__`` is converted to Python numbers or nested
    lists via ``np.asarray(obj).tolist()``. Any other type still raises
    ``TypeError``, so genuinely unserializable data is not silently stringified.
    """
    if hasattr(obj, "__array__"):
        return np.asarray(obj).tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _plot_interpolation_results(lambdas, naive_results, all_perm_results, highlighted_perm_str, plot_path):
    """Save 2x2 interpolation figures (train/test loss and accuracy vs. lambda).

    Only the settings ``'init_ortho_opt'`` and ``'init_ortho_no_opt'`` are
    plotted, and only if present in ``all_perm_results``. For each one, two
    PNGs are written to ``plot_path``: one including the naive-interpolation
    curve, and one (suffix ``_no_naive``) without it.

    Every permutation is drawn in light gray. Permutations listed in
    ``highlighted_perm_str`` (alpha -> permutation string) are drawn in color.
    A train panel whose first permutation's values are all 0.0 (train
    evaluation skipped) shows a placeholder note instead of curves.
    """
    lambda_values = np.array(lambdas)
    metrics = ["Train Loss", "Test Loss", "Train Acc", "Test Acc"]
    positions = [(0, 0), (0, 1), (1, 0), (1, 1)]

    plt.style.use('tableau-colorblind10')
    # One color per highlighted permutation; identical for every panel, so computed once.
    colors = plt.cm.viridis(np.linspace(0, 1, len(highlighted_perm_str)))

    for setting in ['init_ortho_opt', 'init_ortho_no_opt']:
        if setting not in all_perm_results:
            continue
        setting_results = all_perm_results[setting]

        for include_naive in [True, False]:
            fig, axs = plt.subplots(2, 2, figsize=(14, 11))
            fig.suptitle(f'BERT Interpolation Plots: {setting}' + (' with Naive' if include_naive else ''), fontsize=16)

            for metric, pos in zip(metrics, positions):
                ax = axs[pos]

                # Handle skipped train evaluation in plots (detected on the first permutation).
                if "Train" in metric and setting_results:
                    first_perm = next(iter(setting_results))
                    if all(v == 0.0 for v in setting_results[first_perm][metric]):
                        ax.text(0.5, 0.5, 'Train evaluation skipped for speed', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, alpha=0.5)
                        ax.set_title(metric)
                        continue

                # Plot all permutations in light gray
                for perm_results in setting_results.values():
                    ax.plot(lambda_values, perm_results[metric], color='gray', alpha=0.4, linewidth=1.5)
                ax.plot([], [], color='gray', alpha=0.5, label='Other Permutations')  # For legend

                if include_naive:
                    ax.plot(lambda_values, naive_results[metric], label='Naive', color='blue', linestyle='--')

                # Plot highlighted permutations
                for color, (alpha, perm_str) in zip(colors, highlighted_perm_str.items()):
                    if perm_str in setting_results:
                        ax.plot(lambda_values, setting_results[perm_str][metric],
                                label=f"$\\alpha$={alpha}: {perm_str}", color=color)

                ax.set_title(metric)
                ax.set_xlabel(r"$\lambda$")
                ax.set_ylabel("Value")
                ax.grid(True, linestyle='--', alpha=0.6)

            # The "Test Acc" panel is never replaced by the train-skipped note, so take legend entries from it.
            handles, labels = axs[1, 1].get_legend_handles_labels()
            fig.legend(handles, labels, loc='upper left')
            fig.tight_layout(rect=[0, 0, 1, 0.96])

            plot_name = f"interpolation_bert_{setting}" + ("" if include_naive else "_no_naive") + ".png"
            fig.savefig(os.path.join(plot_path, plot_name))
            # Close this specific figure so repeated plots do not accumulate.
            plt.close(fig)


def main():
    """Match attention heads of two fine-tuned IMDB BERT models and evaluate interpolations.

    Steps:
    1. Parse CLI arguments.
    2. Load the IMDB data and both checkpoints.
    3. Collect MHA inputs from each model on training data.
    4. Evaluate naive interpolation between the two parameter sets.
    5. Run ``matching_attn_all_heads_permu`` on the selected layer.
    6. Evaluate the interpolation for every returned head permutation.
    7. Write ``matching_results.json`` and interpolation plots to ``--plot-path``.
    """
    args, finetune_layers = _parse_args()

    # Conditional imports based on rope-use argument
    if not args.rope_use:
        from .imdbreview_bert_train import BertModel, pad_sequences, load_imdb_review_data, get_mha_inputs
    else:
        from .imdbreview_bert_train_rope import BertModel, pad_sequences, load_imdb_review_data, get_mha_inputs

    # --- Configuration ---
    class Config:
        pass
    config = Config()
    config.rope_use = args.rope_use
    config.seed = args.seed
    config.num_layers = args.num_layers
    config.finetune_layer_which = finetune_layers
    config.num_heads = args.num_heads
    config.plot_path = args.plot_path
    os.makedirs(config.plot_path, exist_ok=True)

    # --- Model and Data Loading ---
    print("Loading IMDB data...")
    (x_train, y_train), (x_test, y_test), num_classes, PAD = load_imdb_review_data(args.data_path, max_seq_len=256)
    train_ds = {"token_ids": x_train, "labels": np.array(y_train, dtype=np.int32)}
    test_ds = {"token_ids": x_test, "labels": np.array(y_test, dtype=np.int32)}

    model = BertModel(
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        num_classes=num_classes,
        max_seq_len=256,
    )

    rng = random.PRNGKey(config.seed)

    # Initialize the model once. These params serve two purposes:
    # - the structure template for checkpoint deserialization;
    # - the params of a dummy TrainState whose `params` get swapped per evaluation.
    dummy_input = jnp.zeros((1, model.max_seq_len), dtype=jnp.int32)
    init_rng, dropout_rng = random.split(rng)
    init_params = model.init({'params': init_rng, 'dropout': dropout_rng}, dummy_input)['params']
    dummy_tx = optax.adam(1e-4)  # Dummy optimizer; never stepped.
    dummy_state = TrainState.create(apply_fn=model.apply, params=init_params, tx=dummy_tx)

    def load_model_params(filepath):
        """Deserialize the flax checkpoint at `filepath`, using `init_params` as the structure template."""
        with open(filepath, "rb") as fh:
            return from_bytes(init_params, fh.read())

    # Each split previously seeded a fresh model.init per checkpoint. Those keys
    # are no longer needed because the template above is reused. The splits are
    # kept so `rng` (and thus the activation batch and matching below) stays the
    # same as before for a given --seed.
    rng, _ = random.split(rng)
    model_a_params = load_model_params(args.model_a)
    rng, _ = random.split(rng)
    model_b_params = load_model_params(args.model_b)

    lambdas = jnp.linspace(0, 1, num=11)

    # --- Get Activations ---
    activation_batch_size = 512
    rng, activation_rng = random.split(rng)
    print(f"Computing activations from models A and B on a batch of {activation_batch_size} training examples...")
    # The same key is passed for both models; presumably this makes both see the
    # same training batch. NOTE(review): confirm against get_mha_inputs.
    activations_a = get_mha_inputs(model, model_a_params, train_ds, activation_rng, activation_batch_size, vocab=None, pad_value=PAD)
    activations_b = get_mha_inputs(model, model_b_params, train_ds, activation_rng, activation_batch_size, vocab=None, pad_value=PAD)
    print(f"Got activations for {len(activations_a)} layers from each model.")

    # --- Interpolation & Evaluation Function (Single Device, Efficient) ---
    def dataset_loss_and_accuracy(params, model_apply_fn, dataset, batch_size, PAD):
        """Return the example-weighted mean (loss, accuracy) of `params` over `dataset`.

        `model_apply_fn` is unused. The apply function comes from `dummy_state`,
        whose params are replaced by `params`, which avoids rebuilding a
        TrainState. Both results are Python floats; an empty dataset yields
        (0.0, 0.0).
        """
        state_with_params = dummy_state.replace(params=params)

        num_examples = len(dataset['labels'])
        total_loss_sum = 0.0
        total_acc_sum = 0.0
        total_count = 0

        for start in tqdm(range(0, num_examples, batch_size), desc="Evaluating", leave=False):
            end = min(start + batch_size, num_examples)
            current_bs = end - start

            x_batch = jnp.array(pad_sequences(dataset["token_ids"][start:end], pad_value=PAD))
            batch = {
                'inputs': x_batch,
                # True on non-PAD tokens, with two singleton axes inserted before
                # the last axis (presumably for broadcasting over heads/queries).
                'pad_mask': (x_batch != PAD)[:, None, None, :],
                'labels': jnp.array(dataset["labels"][start:end]),
            }

            metrics = eval_step(state_with_params, batch)

            # eval_step returns per-batch means. Re-weight by batch size so the
            # smaller final batch is not over-counted.
            total_loss_sum += metrics['loss'] * current_bs
            total_acc_sum += metrics['accuracy'] * current_bs
            total_count += current_bs

        if total_count == 0:
            return 0.0, 0.0
        return float(total_loss_sum / total_count), float(total_acc_sum / total_count)

    def compute_interpolation(model_a, model_b_target, lambdas, desc="Interpolation"):
        """Evaluate `lerp(lam, model_a, model_b_target)` on the test set for each `lam`.

        Returns a dict mapping metric name to a list with one value per lambda,
        rounded to 4 decimals. Train metrics are skipped for speed and recorded
        as 0.0; the plotting helper detects and labels this.
        """
        metrics = {"Train Loss": [], "Test Loss": [], "Train Acc": [], "Test Acc": []}
        for lam in tqdm(lambdas, desc=desc):
            p_interp = freeze(lerp(lam, unfreeze(model_a), unfreeze(model_b_target)))
            # Skip train evaluation for speed
            train_loss, train_acc = 0.0, 0.0
            test_loss, test_acc = dataset_loss_and_accuracy(
                p_interp, model.apply, test_ds, args.batch_size, PAD
            )
            metrics["Train Loss"].append(round(float(train_loss), 4))
            metrics["Test Loss"].append(round(float(test_loss), 4))
            metrics["Train Acc"].append(round(float(train_acc), 4))
            metrics["Test Acc"].append(round(float(test_acc), 4))
        return metrics

    # --- Run Matching and Interpolations ---
    # 1. Naive Interpolation
    naive_results = compute_interpolation(model_a_params, model_b_params, lambdas, desc="Naive Interpolation")

    # 2. Get Aligned Parameters for All Permutations
    print("\nStarting attention head matching for all permutations...")
    params_dict, heads_objective_values, heads_permutation_sol = matching_attn_all_heads_permu(
        rng, model_a_params, model_b_params, config.finetune_layer_which, config.num_heads, config.plot_path,
        rope_use=config.rope_use, activations_a=activations_a, activations_b=activations_b,
    )

    # 3. Compute Interpolation for All Permutations
    layer_key = f'TransformerEncoderLayer_{config.finetune_layer_which[0]}'
    attention_key = 'MultiHeadDotProductAttention_0'
    all_perm_results = {}
    for setting, results_by_perm in params_dict.items():
        all_perm_results[setting] = {}
        for perm_str, aligned_result in results_by_perm.items():
            aligned_attn_block = aligned_result.get('aligned_params', aligned_result)
            # unfreeze() builds fresh nested dicts (array leaves are shared), so
            # swapping one sub-dict cannot mutate model_b_params. Deep-copying
            # every parameter array per permutation is unnecessary.
            params_b_full_aligned = unfreeze(model_b_params)
            params_b_full_aligned[layer_key][attention_key] = aligned_attn_block

            all_perm_results[setting][perm_str] = compute_interpolation(
                model_a_params, freeze(params_b_full_aligned), lambdas, desc=f"{setting} {perm_str}"
            )

    # --- Consolidate and Save All Results ---
    # Element [1] of each solution is converted to a list of ints and used as the
    # permutation key (matching the `perm_str` keys of params_dict, presumably).
    highlighted_perm_str = {
        alpha: str([int(i) for i in solution[1]])
        for alpha, solution in heads_permutation_sol.items()
    }

    final_results = {
        "naive_interpolation": naive_results,
        "all_permutation_interpolations": all_perm_results,
        "selected_permutations_by_method": highlighted_perm_str,
        "objective_values": heads_objective_values,
    }

    output_file = os.path.join(config.plot_path, "matching_results.json")
    with open(output_file, "w") as f:
        # The json module rejects NumPy/JAX scalars and arrays. Convert them in
        # case the matching helpers return such values, so results are not lost
        # at the very end of the run.
        json.dump(final_results, f, indent=2, default=_json_default)
    print(f"\nAll results saved to {output_file}")

    # --- Plotting ---
    print("Generating plots...")
    _plot_interpolation_results(lambdas, naive_results, all_perm_results, highlighted_perm_str, config.plot_path)

    print(f"Plots saved in {config.plot_path}")

if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())