from flax.core import freeze, unfreeze
import argparse
import jax.numpy as jnp
from jax import random, lax
import jax
from tqdm import tqdm
from flax.serialization import from_bytes
import json
import numpy as np
import matplotlib.pyplot as plt
import os
import copy
from flax.training.train_state import TrainState
import optax
from flax.jax_utils import replicate, unreplicate
from datasets import Dataset

from src.utils import lerp
from src.matching_utils import matching_attn_all_heads_permu

# --- Evaluation logic for BERT (adapted for pmap) ---
def eval_step(state, batch):
    """Compute masked cross-entropy loss and accuracy for one (sharded) batch.

    Meant to be wrapped as ``jax.pmap(eval_step, axis_name='batch')``: all sums are
    reduced across the ``'batch'`` axis with ``lax.psum``, so the returned metrics are
    means over every valid example of the global batch, not just this device's shard.

    Args:
        state: object exposing ``apply_fn`` and ``params`` (e.g. a flax ``TrainState``).
            ``apply_fn`` is called as ``apply_fn({'params': params}, inputs, pad_mask,
            deterministic=True)`` and must return logits whose last axis indexes classes.
        batch: dict with ``'inputs'``, ``'pad_mask'``, ``'labels'`` (integer class ids)
            and ``'valid_mask'`` (1.0 for real examples, 0.0 for padding examples that
            only exist to fill the batch to a fixed size).

    Returns:
        dict with scalar ``'loss'`` and ``'accuracy'``.
    """
    logits = state.apply_fn({'params': state.params}, batch['inputs'], batch['pad_mask'],
                            deterministic=True)
    one_hot_labels = jax.nn.one_hot(batch['labels'], num_classes=logits.shape[-1])
    per_example_loss = -jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1)
    per_example_acc = (jnp.argmax(logits, -1) == batch['labels']).astype(jnp.float32)

    valid_mask = batch['valid_mask']
    is_valid = valid_mask > 0
    # Mask with jnp.where instead of a bare multiplication: padding examples may yield
    # non-finite logits, and NaN * 0 == NaN would poison the cross-device sum.
    weighted_loss = jnp.where(is_valid, per_example_loss * valid_mask, 0.0)
    weighted_acc = jnp.where(is_valid, per_example_acc * valid_mask, 0.0)

    # The valid-example count is shared by both metrics, so reduce it only once.
    num_valid = lax.psum(jnp.sum(valid_mask), axis_name='batch')
    loss = lax.psum(jnp.sum(weighted_loss), axis_name='batch') / num_valid
    accuracy = lax.psum(jnp.sum(weighted_acc), axis_name='batch') / num_valid

    return {'loss': loss, 'accuracy': accuracy}


def _json_default(obj):
    """``json.dump`` fallback converting NumPy/JAX scalars and arrays via ``.tolist()``.

    Anything without ``tolist`` still raises ``TypeError`` so genuinely unexpected
    objects are reported instead of being silently stringified.
    """
    if hasattr(obj, "tolist"):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _dataset_loss_and_accuracy(p_eval_step, params, apply_fn, dataset, batch_size, pad_value,
                               num_devices, per_device_batch_size, pad_sequences):
    """Evaluate ``params`` on ``dataset`` and return the example-weighted (loss, accuracy).

    Args:
        p_eval_step: ``jax.pmap``-ed ``eval_step`` (axis name ``'batch'``).
        params: model parameters to evaluate.
        apply_fn: the model's ``apply`` function.
        dataset: dict with ``'token_ids'`` (sequence of token-id lists) and ``'labels'``
            (NumPy integer array).
        batch_size: global batch size; must equal ``num_devices * per_device_batch_size``.
        pad_value: token id used both to pad sequences and to build filler examples.
        num_devices: number of local devices the batch is sharded over.
        per_device_batch_size: examples per device.
        pad_sequences: callable invoked as ``pad_sequences(token_ids, pad_value=...)`` that
            returns a rectangular array (presumably padding to the longest row — verify).

    Returns:
        Tuple ``(loss, accuracy)`` averaged over all real examples.

    Raises:
        ValueError: if ``dataset`` contains no examples.
    """
    # The optimizer is a dummy: TrainState just bundles apply_fn and params for pmap.
    state = replicate(TrainState.create(apply_fn=apply_fn, params=params, tx=optax.sgd(1e-3)))
    num_examples = len(dataset['labels'])

    def shard(array):
        # Split the leading (global batch) axis into (device, per-device batch).
        return jnp.reshape(array, (num_devices, per_device_batch_size) + array.shape[1:])

    batch_metrics = []
    batch_weights = []
    for start in tqdm(range(0, num_examples, batch_size), desc="Evaluating", leave=False):
        end = min(start + batch_size, num_examples)
        current_batch_size = end - start
        pad_size = batch_size - current_batch_size

        # Copy into a list so padding never depends on the column's container type.
        token_ids = list(dataset["token_ids"][start:end])
        labels = np.asarray(dataset["labels"][start:end])
        if pad_size > 0:
            # The final batch is filled with dummy examples so it can be sharded evenly;
            # valid_mask below excludes them from the metrics.
            token_ids.extend([[pad_value]] * pad_size)
            labels = np.concatenate([labels, np.zeros(pad_size, dtype=labels.dtype)])

        x_batch = jnp.array(pad_sequences(token_ids, pad_value=pad_value))
        pad_mask = (x_batch != pad_value)[:, None, None, :]
        valid_mask = jnp.array([1.0] * current_batch_size + [0.0] * pad_size)

        batch = {
            'inputs': shard(x_batch),
            'pad_mask': shard(pad_mask),
            'labels': shard(jnp.array(labels)),
            'valid_mask': shard(valid_mask),
        }
        batch_metrics.append(unreplicate(p_eval_step(state, batch)))
        batch_weights.append(current_batch_size)

    if not batch_metrics:
        raise ValueError("Cannot evaluate on an empty dataset.")

    # Each batch metric is already a mean over that batch's valid examples, so weight
    # by the valid count; a plain mean would over-weight a partially filled last batch.
    loss = np.average([m['loss'] for m in batch_metrics], weights=batch_weights)
    accuracy = np.average([m['accuracy'] for m in batch_metrics], weights=batch_weights)
    return loss, accuracy


def _plot_interpolation_results(lambdas, naive_results, all_perm_results, highlighted_perm_str,
                                plot_path):
    """Save 2x2 metric-vs-lambda figures for each matching setting into ``plot_path``.

    For every setting in ``('init_ortho_opt', 'init_ortho_no_opt')`` that is present in
    ``all_perm_results``, two figures are written (with and without the naive curve).
    All permutations are drawn in gray; the permutations selected per alpha are highlighted.
    """
    lambda_values = np.array(lambdas)
    metric_names = ["Train Loss", "Test Loss", "Train Acc", "Test Acc"]
    positions = [(0, 0), (0, 1), (1, 0), (1, 1)]

    plt.style.use('tableau-colorblind10')
    colors = plt.cm.viridis(np.linspace(0, 1, len(highlighted_perm_str)))

    for setting in ['init_ortho_opt', 'init_ortho_no_opt']:
        setting_results = all_perm_results.get(setting)
        if setting_results is None:
            continue

        for include_naive in [True, False]:
            fig, axs = plt.subplots(2, 2, figsize=(14, 11))
            fig.suptitle(f'DBpedia BERT Interpolation Plots: {setting}' + (' with Naive' if include_naive else ''), fontsize=16)

            for metric, pos in zip(metric_names, positions):
                ax = axs[pos]

                # Train metrics are recorded as all-zero placeholders when train
                # evaluation is skipped; show a note instead of flat lines.
                if "Train" in metric and setting_results:
                    first_perm = next(iter(setting_results))
                    if all(v == 0.0 for v in setting_results[first_perm][metric]):
                        ax.text(0.5, 0.5, 'Train evaluation skipped for speed', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, alpha=0.5)
                        ax.set_title(metric)
                        ax.grid(True, linestyle='--', alpha=0.6)
                        continue

                for perm_results in setting_results.values():
                    ax.plot(lambda_values, perm_results[metric], color='gray', alpha=0.4, linewidth=1.5)
                # Invisible line so "Other Permutations" gets a single legend entry.
                ax.plot([], [], color='gray', alpha=0.5, label='Other Permutations')

                if include_naive:
                    ax.plot(lambda_values, naive_results[metric], label='Naive', color='blue', linestyle='--')

                for idx, (alpha, perm_str) in enumerate(highlighted_perm_str.items()):
                    if perm_str in setting_results:
                        ax.plot(lambda_values, setting_results[perm_str][metric],
                                label=f"$\\alpha$={alpha}: {perm_str}", color=colors[idx])

                ax.set_title(metric)
                ax.set_xlabel(r"$\lambda$")
                ax.set_ylabel("Value")
                ax.grid(True, linestyle='--', alpha=0.6)

            handles, labels = axs[1, 1].get_legend_handles_labels()
            fig.legend(handles, labels, loc='upper left')
            plt.tight_layout(rect=[0, 0, 1, 0.96])

            plot_name = f"interpolation_dbpedia_{setting}" + ("" if include_naive else "_no_naive") + ".png"
            plt.savefig(os.path.join(plot_path, plot_name))
            plt.close(fig)


def main():
    """Match attention heads of two fine-tuned DBpedia BERT models and compare interpolations.

    Workflow:
      1. Parse CLI arguments, load both checkpoints and the train/test parquet data.
      2. Evaluate naive linear interpolation between model A and model B on the test set.
      3. Run ``matching_attn_all_heads_permu`` on a single attention layer; for every
         candidate permutation it returns, splice the aligned attention block into
         model B and evaluate the interpolation again.
      4. Write all metrics to ``<plot-path>/matching_results.json`` and save plots.

    Train-set evaluation is skipped for speed; its metrics are recorded as 0.0.
    """
    parser = argparse.ArgumentParser(description="Attention head permutation matching for DBpedia BERT models.")
    parser.add_argument("--rope-use", action='store_true', help="Use RoPE-enabled BERT model and matching if this flag is present")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first fine-tuned BERT model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second fine-tuned BERT model checkpoint")
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
    parser.add_argument("--num-layers", type=int, required=True, help="Number of transformer layers")
    parser.add_argument("--num-heads", type=int, required=True, help="Number of attention heads")
    parser.add_argument("--embedding-dim", type=int, default=48, help="Dimension of token embeddings")
    parser.add_argument("--hidden-dim", type=int, default=192, help="Dimension of the feed-forward network")
    parser.add_argument("--num-classes", type=int, default=219, help="Number of output classes of the classifier head")
    parser.add_argument("--finetune-layer-which", type=str, required=True, help="Which attention layer to match (single layer index)")
    parser.add_argument("--plot-path", type=str, default="./plots/dbpedia_head_matching", help="Path to save plots and results")
    parser.add_argument("--train-dataset-path", type=str, required=True, help="Path to train parquet")
    parser.add_argument("--test-dataset-path", type=str, required=True, help="Path to test parquet")
    parser.add_argument("--batch-size", type=int, default=2048, help="Batch size for evaluation")
    args = parser.parse_args()

    # Conditional imports based on rope-use argument (relative: run this file with -m).
    if not args.rope_use:
        from .dbpedia_bert_train import BertModel, pad_sequences, make_stuff
    else:
        from .dbpedia_bert_train_rope import BertModel, pad_sequences, make_stuff

    # --- Configuration ---
    config = argparse.Namespace(
        rope_use=args.rope_use,
        seed=args.seed,
        num_layers=args.num_layers,
        finetune_layer_which=[int(idx) for idx in args.finetune_layer_which.split(",")],
        num_heads=args.num_heads,
        plot_path=args.plot_path,
    )
    # parser.error instead of assert: validation must survive `python -O`.
    if len(config.finetune_layer_which) != 1:
        parser.error("finetune_layer_which must have exactly one layer for this script")
    os.makedirs(config.plot_path, exist_ok=True)

    # --- JAX Device Setup ---
    num_devices = jax.local_device_count()
    print(f"Using {num_devices} devices.")
    if args.batch_size % num_devices != 0:
        parser.error(f"Batch size ({args.batch_size}) must be divisible by the number of devices ({num_devices}).")
    per_device_batch_size = args.batch_size // num_devices

    # --- Model and Data Loading ---
    print("Loading DBpedia data...")
    train_dataset = Dataset.from_parquet(args.train_dataset_path)
    test_dataset = Dataset.from_parquet(args.test_dataset_path)
    train_ds = {
        "token_ids": train_dataset['input_ids'],
        "labels": np.array(train_dataset['label'], dtype=np.int32)
    }
    test_ds = {
        "token_ids": test_dataset['input_ids'],
        "labels": np.array(test_dataset['label'], dtype=np.int32)
    }
    PAD = 0

    model = BertModel(
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        num_classes=args.num_classes,
        vocab_size=30522,
        max_seq_len=256
    )

    rng = random.PRNGKey(config.seed)

    def load_model_params(filepath, model, rng):
        # Build a parameter template via model.init, then fill it from the serialized bytes.
        dummy_input = jnp.zeros((1, model.max_seq_len), dtype=jnp.int32)
        init_rng, dropout_rng = random.split(rng)
        params_struct = model.init({'params': init_rng, 'dropout': dropout_rng}, dummy_input)['params']
        with open(filepath, "rb") as fh:
            return from_bytes(params_struct, fh.read())

    rng, model_rng_a = random.split(rng)
    model_a_params = load_model_params(args.model_a, model, model_rng_a)
    rng, model_rng_b = random.split(rng)
    model_b_params = load_model_params(args.model_b, model, model_rng_b)

    lambdas = jnp.linspace(0, 1, num=9)

    # --- Get Activations ---
    stuff = make_stuff(model)
    get_mha_inputs = stuff["get_mha_inputs"]
    activation_batch_size = 1024
    rng, activation_rng = random.split(rng)
    print(f"Computing activations from models on a batch of {activation_batch_size} training examples...")
    # The same rng is used for both models, presumably so both see the same examples.
    activations_a = get_mha_inputs(model_a_params, train_ds, activation_rng, activation_batch_size)
    activations_b = get_mha_inputs(model_b_params, train_ds, activation_rng, activation_batch_size)
    print(f"Got activations for {len(activations_a)} layers from each model.")

    # --- Interpolation & Evaluation ---
    # Wrap once and reuse for every evaluation instead of re-wrapping per call.
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    def compute_interpolation(model_a, model_b_target, lambdas, desc="Interpolation"):
        metrics = {"Train Loss": [], "Test Loss": [], "Train Acc": [], "Test Acc": []}
        for lam in tqdm(lambdas, desc=desc):
            p_interp = freeze(lerp(lam, unfreeze(model_a), unfreeze(model_b_target)))
            # Train evaluation is skipped for speed; zeros are placeholders the plots detect.
            train_loss, train_acc = 0.0, 0.0
            test_loss, test_acc = _dataset_loss_and_accuracy(
                p_eval_step, p_interp, model.apply, test_ds, args.batch_size, PAD,
                num_devices, per_device_batch_size, pad_sequences,
            )
            metrics["Train Loss"].append(float(f"{train_loss:.4f}"))
            metrics["Test Loss"].append(float(f"{test_loss:.4f}"))
            metrics["Train Acc"].append(float(f"{train_acc:.4f}"))
            metrics["Test Acc"].append(float(f"{test_acc:.4f}"))
        return metrics

    # --- Run Matching and Interpolations ---
    # 1. Naive Interpolation
    naive_results = compute_interpolation(model_a_params, model_b_params, lambdas, desc="Naive Interpolation")

    # 2. Get Aligned Parameters for All Permutations
    print("\nStarting attention head matching for all permutations...")
    params_dict, heads_objective_values, heads_permutation_sol = matching_attn_all_heads_permu(
        rng, model_a_params, model_b_params, config.finetune_layer_which, config.num_heads, config.plot_path,
        rope_use=config.rope_use, activations_a=activations_a, activations_b=activations_b,
    )

    # 3. Compute Interpolation for All Permutations
    layer_key = f'TransformerEncoderLayer_{config.finetune_layer_which[0]}'
    attention_key = 'MultiHeadDotProductAttention_0'
    all_perm_results = {}
    for setting, perm_to_result in params_dict.items():
        all_perm_results[setting] = {}
        for perm_str, aligned_result in perm_to_result.items():
            aligned_attn_block = aligned_result.get('aligned_params', aligned_result)
            params_b_full_aligned = copy.deepcopy(unfreeze(model_b_params))
            # Only the matched layer's attention block is replaced in model B.
            params_b_full_aligned[layer_key][attention_key] = aligned_attn_block

            results = compute_interpolation(model_a_params, freeze(params_b_full_aligned), lambdas, desc=f"{setting} {perm_str}")
            all_perm_results[setting][perm_str] = results

    # --- Consolidate and Save All Results ---
    highlighted_perm_str = {
        alpha: str([int(_) for _ in heads_permutation_sol[alpha][1]])
        for alpha in heads_permutation_sol
    }

    final_results = {
        "naive_interpolation": naive_results,
        "all_permutation_interpolations": all_perm_results,
        "selected_permutations_by_method": highlighted_perm_str,
        "objective_values": heads_objective_values
    }

    output_file = os.path.join(config.plot_path, "matching_results.json")
    with open(output_file, "w") as f:
        # The matching helper's objective values may hold NumPy/JAX numbers.
        json.dump(final_results, f, indent=2, default=_json_default)
    print(f"\nAll results saved to {output_file}")

    # --- Plotting ---
    print("Generating plots...")
    _plot_interpolation_results(lambdas, naive_results, all_perm_results, highlighted_perm_str,
                                config.plot_path)
    print(f"Plots saved in {config.plot_path}")

if __name__ == "__main__":
    # Script entry point. main() uses package-relative imports, so launch this file
    # as a module (``python -m <package>.<module>``) rather than by path.
    raise SystemExit(main())