from flax.core import freeze, unfreeze
import argparse
import jax.numpy as jnp
from jax import random, vmap
import jax
from tqdm import tqdm
from flax.serialization import from_bytes
import json
import numpy as np
import matplotlib.pyplot as plt
import os
import copy

#from .cifar100_vit_train import ViTModel, make_stuff
from src.datasets import load_cifar100
from src.utils import lerp
from src.matching_utils import matching_attn_all_heads_permu

def _parse_single_layer(parser, raw_value):
    """Parse ``--finetune-layer-which`` into a one-element list of layer indices.

    Args:
        parser: The ``argparse.ArgumentParser`` used to report invalid input.
        raw_value: Comma-separated string of integer layer indices.

    Returns:
        A list holding exactly one ``int`` layer index.

    Exits through ``parser.error`` (``SystemExit`` with status 2) when the
    value is not a comma-separated list of integers, or when it does not name
    exactly one layer.
    """
    try:
        layers = [int(idx) for idx in raw_value.split(",")]
    except ValueError:
        parser.error(
            f"--finetune-layer-which must be a comma-separated list of integers, got {raw_value!r}"
        )
    if len(layers) != 1:
        parser.error("finetune_layer_which must have exactly one layer for this script")
    return layers


def _json_default(obj):
    """Fallback encoder for ``json.dump``: convert array-like values via ``tolist()``.

    Only invoked for objects the ``json`` module cannot serialize natively, so
    plain Python data is written exactly as before.

    Raises:
        TypeError: If ``obj`` has no callable ``tolist`` attribute.
    """
    tolist = getattr(obj, "tolist", None)
    if callable(tolist):
        return tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _plot_interpolations(all_perm_results, naive_results, highlighted_perm_str, lambda_values, plot_path):
    """Save 2x2 interpolation figures for each matching setting present in the results.

    For each of the settings ``init_ortho_opt`` and ``init_ortho_no_opt`` found
    in ``all_perm_results``, two PNG files are written to ``plot_path``: one
    including the naive interpolation curve and one without it
    (``_no_naive`` suffix).

    Args:
        all_perm_results: ``{setting: {perm_str: {metric: [values]}}}``.
        naive_results: ``{metric: [values]}`` for the unaligned interpolation.
        highlighted_perm_str: ``{alpha: perm_str}``; these permutations are
            drawn in colour on top of the gray curves.
        lambda_values: Interpolation coefficients used as the x axis.
        plot_path: Directory the PNG files are written to.
    """
    metrics = ["Train Loss", "Test Loss", "Train Acc", "Test Acc"]
    positions = [(0, 0), (0, 1), (1, 0), (1, 1)]

    plt.style.use('tableau-colorblind10')

    # One colour per highlighted permutation; identical for every panel and figure.
    colors = plt.cm.viridis(np.linspace(0, 1, len(highlighted_perm_str)))

    for setting in ['init_ortho_opt', 'init_ortho_no_opt']:
        if setting not in all_perm_results:
            continue
        setting_results = all_perm_results[setting]

        for include_naive in [True, False]:
            fig, axs = plt.subplots(2, 2, figsize=(14, 11))
            fig.suptitle(f'Interpolation Plots: {setting}' + (' with Naive' if include_naive else ''), fontsize=16)

            for metric, pos in zip(metrics, positions):
                ax = axs[pos]

                # Every permutation as a light gray background curve.
                for perm_results in setting_results.values():
                    ax.plot(lambda_values, perm_results[metric], color='gray', alpha=0.4, linewidth=1.5)

                # Empty proxy line so the gray curves share a single legend entry.
                ax.plot([], [], color='gray', alpha=0.5, label='Other Permutations')

                if include_naive:
                    ax.plot(lambda_values, naive_results[metric], label='Naive', color='blue', linestyle='--')

                # Highlighted permutations; skipped if absent from this setting.
                for idx, (alpha, perm_str) in enumerate(highlighted_perm_str.items()):
                    if perm_str in setting_results:
                        ax.plot(lambda_values, setting_results[perm_str][metric],
                                label=f"$\\alpha$={alpha}: {perm_str}", color=colors[idx])

                ax.set_title(metric)
                ax.set_xlabel(r"$\lambda$")
                ax.set_ylabel("Value")
                ax.grid(True, linestyle='--', alpha=0.6)

            # All four panels carry the same labelled series, so the first
            # panel's handles serve as the figure-level legend.
            handles, labels = axs[0, 0].get_legend_handles_labels()
            fig.legend(handles, labels, loc='upper left')
            # Leave room at the top for the suptitle.
            fig.tight_layout(rect=[0, 0, 1, 0.96])

            plot_name = f"interpolation_{setting}" + ("" if include_naive else "_no_naive") + ".png"
            fig.savefig(os.path.join(plot_path, plot_name))
            # Close this specific figure so figures do not accumulate.
            plt.close(fig)


def main():
    """Compare naive and head-aligned interpolation between two ViT checkpoints.

    Steps:
      1. Parse command-line arguments and validate ``--finetune-layer-which``.
      2. Load both checkpoints and the CIFAR-100 train/test splits.
      3. Collect attention-input activations from both models on one batch.
      4. Evaluate test loss/accuracy along the interpolation path (11 values of
         lambda in [0, 1]) for the unaligned pair and for every aligned
         attention block returned by ``matching_attn_all_heads_permu``.
      5. Write ``matching_results.json`` and the interpolation plots to
         ``--plot-path``.

    Invalid ``--finetune-layer-which`` values terminate the program through
    ``parser.error`` (exit status 2).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--rope-use", action='store_true', help="Use rope if this flag is present")
    parser.add_argument("--model-a", type=str, required=True, help="Path to first fine-tuned ViT model checkpoint")
    parser.add_argument("--model-b", type=str, required=True, help="Path to second fine-tuned ViT model checkpoint")
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
    parser.add_argument("--num-layers", type=int, required=True, help="Number of transformer layers")
    parser.add_argument("--finetune-layer-which", type=str, required=True, help="Which attention layer to match (single layer)")
    parser.add_argument("--num-heads", type=int, required=True, help="Number of attention heads")
    parser.add_argument("--plot-path", type=str, default="./plots", help="Path to save plots and results")
    parser.add_argument("--data-dir", type=str, default="/root/log/cifar100/data", help="Directory passed to load_cifar100 as data_dir")
    args = parser.parse_args()

    # Validate up front (and without `assert`, which is stripped under -O) so a
    # bad value fails before any model or dataset is loaded.
    finetune_layers = _parse_single_layer(parser, args.finetune_layer_which)

    # Conditional imports based on the rope-use argument. These are relative
    # imports, so the script has to run as a module inside its package.
    if not args.rope_use:
        from .cifar100_vit_train import ViTModel, make_stuff
    else:
        # Assumes a RoPE-enabled version of the training script exists
        from .cifar100_vit_train_rope import ViTModel, make_stuff

    os.makedirs(args.plot_path, exist_ok=True)
    batch_size = 5000

    # --- Model and Data Loading ---
    model = ViTModel(num_layers=args.num_layers, num_heads=args.num_heads)
    stuff = make_stuff(model)

    def load_model(filepath):
        """Deserialize a checkpoint into a freshly initialized parameter tree."""
        with open(filepath, "rb") as fh:
            return from_bytes(
                # Template parameters built from a dummy 1x32x32x3 input.
                model.init(random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)))["params"],
                fh.read()
            )

    model_a = load_model(args.model_a)
    model_b = load_model(args.model_b)

    train_ds, test_ds = load_cifar100(data_dir=args.data_dir)
    lambdas = jnp.linspace(0, 1, num=11)
    rng = random.PRNGKey(args.seed)

    # --- Activations ---
    # Computed before matching because they are passed to the matching routine.
    activation_batch_size = 1024
    rng, activation_rng = random.split(rng)
    print(f"Computing activations from models A and B on a batch of {activation_batch_size} training examples...")
    # The same key is used for both models — presumably so both see the same
    # batch; confirm against get_mha_inputs.
    activations_a = stuff["get_mha_inputs"](model_a, train_ds, activation_rng, activation_batch_size)
    activations_b = stuff["get_mha_inputs"](model_b, train_ds, activation_rng, activation_batch_size)
    print(f"Got activations for {len(activations_a)} layers from each model.")

    # --- Interpolation Function ---
    def compute_interpolation(model_a, model_b_target, lambdas, stuff, train_ds, test_ds, batch_size, desc="Interpolation"):
        """Evaluate metrics for parameters interpolated between two models.

        For each ``lam`` in ``lambdas`` the parameters are combined with
        ``lerp(lam, model_a, model_b_target)`` (presumably A at 0 and the
        target at 1 — confirm in src.utils) and evaluated on ``test_ds``.

        Returns:
            Dict mapping "Train Loss", "Test Loss", "Train Acc" and "Test Acc"
            to lists with one entry per lambda, rounded to 4 decimals.
        """
        metrics = {"Train Loss": [], "Test Loss": [], "Train Acc": [], "Test Acc": []}
        for lam in tqdm(lambdas, desc=desc):
            p_interp = freeze(lerp(lam, unfreeze(model_a), unfreeze(model_b_target)))
            # NOTE: training-set evaluation is disabled; zeros are recorded as
            # placeholders, so the train curves are flat. To re-enable, call
            # stuff["dataset_loss_and_accuracy"](p_interp, train_ds, batch_size).
            train_loss, train_acc = 0, 0
            test_loss, test_acc = stuff["dataset_loss_and_accuracy"](p_interp, test_ds, batch_size)
            # Format-then-parse rounds to 4 decimals and yields plain floats.
            metrics["Train Loss"].append(float(f"{train_loss:.4f}"))
            metrics["Test Loss"].append(float(f"{test_loss:.4f}"))
            metrics["Train Acc"].append(float(f"{train_acc:.4f}"))
            metrics["Test Acc"].append(float(f"{test_acc:.4f}"))
        return metrics

    # --- Run Matching and Interpolations ---
    # 1. Naive Interpolation
    naive_results = compute_interpolation(model_a, model_b, lambdas, stuff, train_ds, test_ds, batch_size, desc="Naive Interpolation")

    # 2. Get Aligned Parameters for All Permutations
    print("\nStarting attention head matching for all permutations...")
    params_dict, heads_objective_values, heads_permutation_sol = matching_attn_all_heads_permu(
        rng, model_a, model_b, finetune_layers, args.num_heads, args.plot_path,
        rope_use=args.rope_use, activations_a=activations_a, activations_b=activations_b
    )

    # 3. Compute Interpolation for All Permutations
    # The parameter-tree location of the matched attention block is the same
    # for every permutation.
    layer_key = f'TransformerEncoderLayer_{finetune_layers[0]}'
    attention_key = 'MultiHeadDotProductAttention_0'

    all_perm_results = {}
    for setting in params_dict:
        all_perm_results[setting] = {}
        for perm_str, aligned_result in params_dict[setting].items():
            # An entry is either a dict holding the block under 'aligned_params'
            # or the block itself.
            aligned_attn_block = aligned_result.get('aligned_params', aligned_result)
            # Rebuild the full model B with only the aligned attention block swapped in.
            params_b_full_aligned = copy.deepcopy(unfreeze(model_b))
            params_b_full_aligned[layer_key][attention_key] = aligned_attn_block

            results = compute_interpolation(model_a, freeze(params_b_full_aligned), lambdas, stuff, train_ds, test_ds, batch_size, desc=f"{setting} {perm_str}")
            all_perm_results[setting][perm_str] = results

    # --- Consolidate and Save All Results to a Single JSON File ---
    # Element [1] of each solution is stringified the same way as the perm_str
    # keys are expected to be, so it can be looked up when plotting.
    highlighted_perm_str = {
        alpha: str([int(col) for col in solution[1]])
        for alpha, solution in heads_permutation_sol.items()
    }

    final_results = {
        "naive_interpolation": naive_results,
        "all_permutation_interpolations": all_perm_results,
        "selected_permutations_by_method": highlighted_perm_str,
        "objective_values": heads_objective_values
    }

    output_file = os.path.join(args.plot_path, "matching_results.json")
    with open(output_file, "w") as f:
        # objective_values comes from the matching helper and may contain
        # NumPy/JAX values (not verifiable here); the fallback converts them
        # rather than losing the whole run at this point.
        json.dump(final_results, f, indent=2, default=_json_default)
    print(f"\nAll results saved to {output_file}")

    # --- Plotting ---
    print("Generating plots...")
    _plot_interpolations(all_perm_results, naive_results, highlighted_perm_str, np.array(lambdas), args.plot_path)

    print(f"Plots saved in {args.plot_path}")

# Run the experiment only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()